diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 13de5f1f62..bfce541ca7 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,9 +19,6 @@ import random
from abc import ABCMeta
from typing import Any, Optional
-from six import PY2
-from six.moves import builtins
-
from canonicaljson import json
from synapse.storage.database import LoggingTransaction # noqa: F401
@@ -47,6 +44,9 @@ class SQLBaseStore(metaclass=ABCMeta):
self.db = database
self.rand = random.SystemRandom()
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ pass
+
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
@@ -100,11 +100,6 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
- # psycopg2 on Python 2 returns buffer objects, which we need to cast to
- # bytes to decode
- if PY2 and isinstance(db_content, builtins.buffer):
- db_content = bytes(db_content)
-
# Decode it to a Unicode string before feeding it to json.loads, so we
# consistently get a Unicode-containing object out.
if isinstance(db_content, (bytes, bytearray)):
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index eb1a7e5002..59f3394b0a 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -90,8 +90,10 @@ class BackgroundUpdater(object):
self._clock = hs.get_clock()
self.db = database
+ # if a background update is currently running, its name.
+ self._current_background_update = None # type: Optional[str]
+
self._background_update_performance = {}
- self._background_update_queue = []
self._background_update_handlers = {}
self._all_done = False
@@ -111,7 +113,7 @@ class BackgroundUpdater(object):
except Exception:
logger.exception("Error doing update")
else:
- if result is None:
+ if result:
logger.info(
"No more background updates to do."
" Unscheduling background update task."
@@ -119,26 +121,25 @@ class BackgroundUpdater(object):
self._all_done = True
return None
- @defer.inlineCallbacks
- def has_completed_background_updates(self):
+ async def has_completed_background_updates(self) -> bool:
"""Check if all the background updates have completed
Returns:
- Deferred[bool]: True if all background updates have completed
+ True if all background updates have completed
"""
# if we've previously determined that there is nothing left to do, that
# is easy
if self._all_done:
return True
- # obviously, if we have things in our queue, we're not done.
- if self._background_update_queue:
+ # obviously, if we are currently processing an update, we're not done.
+ if self._current_background_update:
return False
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
- updates = yield self.db.simple_select_onecol(
+ updates = await self.db.simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
@@ -153,11 +154,10 @@ class BackgroundUpdater(object):
async def has_completed_background_update(self, update_name) -> bool:
"""Check if the given background update has finished running.
"""
-
if self._all_done:
return True
- if update_name in self._background_update_queue:
+ if update_name == self._current_background_update:
return False
update_exists = await self.db.simple_select_one_onecol(
@@ -170,9 +170,7 @@ class BackgroundUpdater(object):
return not update_exists
- async def do_next_background_update(
- self, desired_duration_ms: float
- ) -> Optional[int]:
+ async def do_next_background_update(self, desired_duration_ms: float) -> bool:
"""Does some amount of work on the next queued background update
Returns once some amount of work is done.
@@ -181,33 +179,51 @@ class BackgroundUpdater(object):
desired_duration_ms(float): How long we want to spend
updating.
Returns:
- None if there is no more work to do, otherwise an int
+ True if we have finished running all the background updates, otherwise False
"""
- if not self._background_update_queue:
- updates = await self.db.simple_select_list(
- "background_updates",
- keyvalues=None,
- retcols=("update_name", "depends_on"),
+
+ def get_background_updates_txn(txn):
+ txn.execute(
+ """
+ SELECT update_name, depends_on FROM background_updates
+ ORDER BY ordering, update_name
+ """
)
- in_flight = {update["update_name"] for update in updates}
- for update in updates:
- if update["depends_on"] not in in_flight:
- self._background_update_queue.append(update["update_name"])
+ return self.db.cursor_to_dict(txn)
- if not self._background_update_queue:
- # no work left to do
- return None
+ if not self._current_background_update:
+ all_pending_updates = await self.db.runInteraction(
+ "background_updates", get_background_updates_txn,
+ )
+ if not all_pending_updates:
+ # no work left to do
+ return True
+
+ # find the first update which isn't dependent on another one in the queue.
+ pending = {update["update_name"] for update in all_pending_updates}
+ for upd in all_pending_updates:
+ depends_on = upd["depends_on"]
+ if not depends_on or depends_on not in pending:
+ break
+ logger.info(
+ "Not starting on bg update %s until %s is done",
+ upd["update_name"],
+ depends_on,
+ )
+ else:
+ # if we get to the end of that for loop, there is a problem
+ raise Exception(
+ "Unable to find a background update which doesn't depend on "
+ "another: dependency cycle?"
+ )
- # pop from the front, and add back to the back
- update_name = self._background_update_queue.pop(0)
- self._background_update_queue.append(update_name)
+ self._current_background_update = upd["update_name"]
- res = await self._do_background_update(update_name, desired_duration_ms)
- return res
+ await self._do_background_update(desired_duration_ms)
+ return False
- async def _do_background_update(
- self, update_name: str, desired_duration_ms: float
- ) -> int:
+ async def _do_background_update(self, desired_duration_ms: float) -> int:
+ update_name = self._current_background_update
logger.info("Starting update batch on background update '%s'", update_name)
update_handler = self._background_update_handlers[update_name]
@@ -400,27 +416,6 @@ class BackgroundUpdater(object):
self.register_background_update_handler(update_name, updater)
- def start_background_update(self, update_name, progress):
- """Starts a background update running.
-
- Args:
- update_name: The update to set running.
- progress: The initial state of the progress of the update.
-
- Returns:
- A deferred that completes once the task has been added to the
- queue.
- """
- # Clear the background update queue so that we will pick up the new
- # task on the next iteration of do_background_update.
- self._background_update_queue = []
- progress_json = json.dumps(progress)
-
- return self.db.simple_insert(
- "background_updates",
- {"update_name": update_name, "progress_json": progress_json},
- )
-
def _end_background_update(self, update_name):
"""Removes a completed background update task from the queue.
@@ -429,9 +424,12 @@ class BackgroundUpdater(object):
Returns:
A deferred that completes once the task is removed.
"""
- self._background_update_queue = [
- name for name in self._background_update_queue if name != update_name
- ]
+ if update_name != self._current_background_update:
+ raise Exception(
+ "Cannot end background update %s which isn't currently running"
+ % update_name
+ )
+ self._current_background_update = None
return self.db.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
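
With this change `do_next_background_update` returns True once there is nothing left to run, rather than using None as a sentinel, which is what the relaxed `if result:` check above relies on. A minimal sketch of a driver loop built on that contract (the loop and sleep interval here are illustrative, not Synapse's actual scheduler):

```python
import asyncio


async def run_all_background_updates(updater, desired_duration_ms=100.0, pause_s=1.0):
    """Keep doing batches of background-update work until the updater reports
    that every registered update has completed."""
    while True:
        finished = await updater.do_next_background_update(desired_duration_ms)
        if finished:
            break
        # Yield to the event loop between batches so we don't hog the reactor.
        await asyncio.sleep(pause_s)
```
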
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index e1d03429ca..599ee470d4 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -15,6 +15,7 @@
import logging
+from synapse.storage.data_stores.main.events import PersistEventsStore
from synapse.storage.data_stores.state import StateGroupDataStore
from synapse.storage.database import Database, make_conn
from synapse.storage.engines import create_engine
@@ -39,6 +40,7 @@ class DataStores(object):
self.databases = []
self.main = None
self.state = None
+ self.persist_events = None
for database_config in hs.config.database.databases:
db_name = database_config.name
@@ -64,6 +66,13 @@ class DataStores(object):
self.main = main_store_class(database, db_conn, hs)
+ # If we're on a process that can persist events also
+ # instantiate a `PersistEventsStore`
+ if hs.config.worker.writers.events == hs.get_instance_name():
+ self.persist_events = PersistEventsStore(
+ hs, database, self.main
+ )
+
if "state" in database_config.data_stores:
logger.info("Starting 'state' data store")
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index acca079f23..4b4763c701 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -24,15 +24,16 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
- ChainedIdGenerator,
IdGenerator,
+ MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
-from .cache import CacheInvalidationStore
+from .cache import CacheInvalidationWorkerStore
+from .censor_events import CensorEventsStore
from .client_ips import ClientIpStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
@@ -41,16 +42,17 @@ from .e2e_room_keys import EndToEndRoomKeyStore
from .end_to_end_keys import EndToEndKeyStore
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
-from .events import EventsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
from .media_repository import MediaRepositoryStore
+from .metrics import ServerMetricsStore
from .monthly_active_users import MonthlyActiveUsersStore
from .openid import OpenIdStore
from .presence import PresenceStore, UserPresenceState
from .profile import ProfileStore
+from .purge_events import PurgeEventsStore
from .push_rule import PushRuleStore
from .pusher import PusherStore
from .receipts import ReceiptsStore
@@ -66,6 +68,7 @@ from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
+from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore
@@ -86,7 +89,7 @@ class DataStore(
StateStore,
SignatureStore,
ApplicationServiceStore,
- EventsStore,
+ PurgeEventsStore,
EventFederationStore,
MediaRepositoryStore,
RejectionsStore,
@@ -111,26 +114,16 @@ class DataStore(
MonthlyActiveUsersStore,
StatsStore,
RelationsStore,
- CacheInvalidationStore,
+ CensorEventsStore,
+ UIAuthStore,
+ CacheInvalidationWorkerStore,
+ ServerMetricsStore,
):
def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine
- self._stream_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- extra_tables=[("local_invites", "stream_id")],
- )
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
- )
self._presence_id_gen = StreamIdGenerator(
db_conn, "presence_stream", "stream_id"
)
@@ -144,7 +137,10 @@ class DataStore(
db_conn,
"device_lists_stream",
"stream_id",
- extra_tables=[("user_signature_stream", "stream_id")],
+ extra_tables=[
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ],
)
self._cross_signing_id_gen = StreamIdGenerator(
db_conn, "e2e_cross_signing_keys", "stream_id"
@@ -154,9 +150,6 @@ class DataStore(
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._push_rules_stream_id_gen = ChainedIdGenerator(
- self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
- )
self._pushers_id_gen = StreamIdGenerator(
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
@@ -165,8 +158,14 @@ class DataStore(
)
if isinstance(self.database_engine, PostgresEngine):
- self._cache_id_gen = StreamIdGenerator(
- db_conn, "cache_invalidation_stream", "stream_id"
+ self._cache_id_gen = MultiWriterIdGenerator(
+ db_conn,
+ database,
+ instance_name="master",
+ table="cache_invalidation_stream_by_instance",
+ instance_column="instance_name",
+ id_column="stream_id",
+ sequence_name="cache_invalidation_stream_seq",
)
else:
self._cache_id_gen = None
@@ -500,7 +499,8 @@ class DataStore(
self, start, limit, name=None, guests=True, deactivated=False
):
"""Function to retrieve a paginated list of users from
- users list. This will return a json list of users.
+ users list. This will return a json list of users and the
+ total number of users matching the filter criteria.
Args:
start (int): start number to begin the query from
@@ -509,35 +509,44 @@ class DataStore(
guests (bool): whether to include guest users
deactivated (bool): whether to include deactivated users
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ defer.Deferred: resolves to a tuple of (list[dict[str, Any]], int)
"""
- name_filter = {}
- if name:
- name_filter["name"] = "%" + name + "%"
-
- attr_filter = {}
- if not guests:
- attr_filter["is_guest"] = 0
- if not deactivated:
- attr_filter["deactivated"] = 0
-
- return self.db.simple_select_list_paginate(
- desc="get_users_paginate",
- table="users",
- orderby="name",
- start=start,
- limit=limit,
- filters=name_filter,
- keyvalues=attr_filter,
- retcols=[
- "name",
- "password_hash",
- "is_guest",
- "admin",
- "user_type",
- "deactivated",
- ],
- )
+
+ def get_users_paginate_txn(txn):
+ filters = []
+ args = []
+
+ if name:
+ filters.append("name LIKE ?")
+ args.append("%" + name + "%")
+
+ if not guests:
+ filters.append("is_guest = 0")
+
+ if not deactivated:
+ filters.append("deactivated = 0")
+
+ where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+ sql = "SELECT COUNT(*) as total_users FROM users %s" % (where_clause)
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ args = [self.hs.config.server_name] + args + [limit, start]
+ sql = """
+ SELECT name, user_type, is_guest, admin, deactivated, displayname, avatar_url
+ FROM users as u
+ LEFT JOIN profiles AS p ON u.name = '@' || p.user_id || ':' || ?
+ {}
+ ORDER BY u.name LIMIT ? OFFSET ?
+ """.format(
+ where_clause
+ )
+ txn.execute(sql, args)
+ users = self.db.cursor_to_dict(txn)
+ return users, count
+
+ return self.db.runInteraction("get_users_paginate_txn", get_users_paginate_txn)
def search_users(self, term):
"""Function to search users list for one or more users with
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 46b494b334..b58f04d00d 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -16,6 +16,7 @@
import abc
import logging
+from typing import List, Tuple
from canonicaljson import json
@@ -175,41 +176,64 @@ class AccountDataWorkerStore(SQLBaseStore):
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
- def get_all_updated_account_data(
- self, last_global_id, last_room_id, current_id, limit
- ):
- """Get all the client account_data that has changed on the server
+ async def get_updated_global_account_data(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple[int, str, str]]:
+ """Get the global account_data that has changed, for the account_data stream
+
Args:
- last_global_id(int): The position to fetch from for top level data
- last_room_id(int): The position to fetch from for per room data
- current_id(int): The position to fetch up to.
+ last_id: the last stream_id from the previous batch.
+ current_id: the maximum stream_id to return up to
+ limit: the maximum number of rows to return
+
Returns:
- A deferred pair of lists of tuples of stream_id int, user_id string,
- room_id string, and type string.
+ A list of tuples of stream_id int, user_id string,
+ and type string.
"""
- if last_room_id == current_id and last_global_id == current_id:
- return defer.succeed(([], []))
+ if last_id == current_id:
+ return []
- def get_updated_account_data_txn(txn):
+ def get_updated_global_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
- txn.execute(sql, (last_global_id, current_id, limit))
- global_results = txn.fetchall()
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ return await self.db.runInteraction(
+ "get_updated_global_account_data", get_updated_global_account_data_txn
+ )
+
+ async def get_updated_room_account_data(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[Tuple[int, str, str, str]]:
+ """Get the global account_data that has changed, for the account_data stream
+
+ Args:
+ last_id: the last stream_id from the previous batch.
+ current_id: the maximum stream_id to return up to
+ limit: the maximum number of rows to return
+
+ Returns:
+ A list of tuples of stream_id int, user_id string,
+ room_id string and type string.
+ """
+ if last_id == current_id:
+ return []
+ def get_updated_room_account_data_txn(txn):
sql = (
"SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
- txn.execute(sql, (last_room_id, current_id, limit))
- room_results = txn.fetchall()
- return global_results, room_results
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
- return self.db.runInteraction(
- "get_all_updated_account_data_txn", get_updated_account_data_txn
+ return await self.db.runInteraction(
+ "get_updated_room_account_data", get_updated_room_account_data_txn
)
def get_updated_account_data_for_user(self, user_id, stream_id):
@@ -273,7 +297,13 @@ class AccountDataWorkerStore(SQLBaseStore):
class AccountDataStore(AccountDataWorkerStore):
def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
- db_conn, "account_data_max_stream_id", "stream_id"
+ db_conn,
+ "account_data_max_stream_id",
+ "stream_id",
+ extra_tables=[
+ ("room_account_data", "stream_id"),
+ ("room_tags_revisions", "stream_id"),
+ ],
)
super(AccountDataStore, self).__init__(database, db_conn, hs)
@@ -363,6 +393,10 @@ class AccountDataStore(AccountDataWorkerStore):
# doesn't sound any worse than the whole update getting lost,
# which is what would happen if we combined the two into one
# transaction.
+ #
+ # Note: This is only here for backwards compat to allow admins to
+ # roll back to a previous Synapse version. Next time we update the
+ # database version we can remove this table.
yield self._update_max_stream_id(next_id)
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
@@ -381,6 +415,10 @@ class AccountDataStore(AccountDataWorkerStore):
next_id(int): The revision to advance to.
"""
+ # Note: This is only here for backwards compat to allow admins to
+ # roll back to a previous Synapse version. Next time we update the
+ # database version we can remove this table.
+
def _update(txn):
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index efbc06c796..7a1fe8cdd2 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -30,12 +30,12 @@ logger = logging.getLogger(__name__)
def _make_exclusive_regex(services_cache):
- # We precompie a regex constructed from all the regexes that the AS's
+ # We precompile a regex constructed from all the regexes that the AS's
# have registered for exclusive users.
exclusive_user_regexes = [
regex.pattern
for service in services_cache
- for regex in service.get_exlusive_user_regexes()
+ for regex in service.get_exclusive_user_regexes()
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index d4c44dcc75..eac5a4e55b 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -18,9 +18,13 @@ import itertools
import logging
from typing import Any, Iterable, Optional, Tuple
-from twisted.internet import defer
-
+from synapse.api.constants import EventTypes
+from synapse.replication.tcp.streams.events import (
+ EventsStreamCurrentStateRow,
+ EventsStreamEventRow,
+)
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter
@@ -32,7 +36,133 @@ logger = logging.getLogger(__name__)
CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
-class CacheInvalidationStore(SQLBaseStore):
+class CacheInvalidationWorkerStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self._instance_name = hs.get_instance_name()
+
+ async def get_all_updated_caches(
+ self, instance_name: str, last_id: int, current_id: int, limit: int
+ ):
+ """Fetches cache invalidation rows between the two given IDs written
+ by the given instance. Returns at most `limit` rows.
+ """
+
+ if last_id == current_id:
+ return []
+
+ def get_all_updated_caches_txn(txn):
+ # We purposefully don't bound by the current token, as we want to
+ # send across cache invalidations as quickly as possible. Cache
+ # invalidations are idempotent, so duplicates are fine.
+ sql = """
+ SELECT stream_id, cache_func, keys, invalidation_ts
+ FROM cache_invalidation_stream_by_instance
+ WHERE stream_id > ? AND instance_name = ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, instance_name, limit))
+ return txn.fetchall()
+
+ return await self.db.runInteraction(
+ "get_all_updated_caches", get_all_updated_caches_txn
+ )
+
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
+ if stream_name == "events":
+ for row in rows:
+ self._process_event_stream_row(token, row)
+ elif stream_name == "backfill":
+ for row in rows:
+ self._invalidate_caches_for_event(
+ -token,
+ row.event_id,
+ row.room_id,
+ row.type,
+ row.state_key,
+ row.redacts,
+ row.relates_to,
+ backfilled=True,
+ )
+ elif stream_name == "caches":
+ if self._cache_id_gen:
+ self._cache_id_gen.advance(instance_name, token)
+
+ for row in rows:
+ if row.cache_func == CURRENT_STATE_CACHE_NAME:
+ if row.keys is None:
+ raise Exception(
+ "Can't send an 'invalidate all' for current state cache"
+ )
+
+ room_id = row.keys[0]
+ members_changed = set(row.keys[1:])
+ self._invalidate_state_caches(room_id, members_changed)
+ else:
+ self._attempt_to_invalidate_cache(row.cache_func, row.keys)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ def _process_event_stream_row(self, token, row):
+ data = row.data
+
+ if row.type == EventsStreamEventRow.TypeId:
+ self._invalidate_caches_for_event(
+ token,
+ data.event_id,
+ data.room_id,
+ data.type,
+ data.state_key,
+ data.redacts,
+ data.relates_to,
+ backfilled=False,
+ )
+ elif row.type == EventsStreamCurrentStateRow.TypeId:
+ self._curr_state_delta_stream_cache.entity_has_changed(
+ row.data.room_id, token
+ )
+
+ if data.type == EventTypes.Member:
+ self.get_rooms_for_user_with_stream_ordering.invalidate(
+ (data.state_key,)
+ )
+ else:
+ raise Exception("Unknown events stream row type %s" % (row.type,))
+
+ def _invalidate_caches_for_event(
+ self,
+ stream_ordering,
+ event_id,
+ room_id,
+ etype,
+ state_key,
+ redacts,
+ relates_to,
+ backfilled,
+ ):
+ self._invalidate_get_event_cache(event_id)
+
+ self.get_latest_event_ids_in_room.invalidate((room_id,))
+
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
+
+ if not backfilled:
+ self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
+
+ if redacts:
+ self._invalidate_get_event_cache(redacts)
+
+ if etype == EventTypes.Member:
+ self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
+ self.get_invited_rooms_for_local_user.invalidate((state_key,))
+
+ if relates_to:
+ self.get_relations_for_event.invalidate_many((relates_to,))
+ self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+ self.get_applicable_edit.invalidate((relates_to,))
+
async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
"""Invalidates the cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -46,7 +176,7 @@ class CacheInvalidationStore(SQLBaseStore):
return
cache_func.invalidate(keys)
- await self.runInteraction(
+ await self.db.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
cache_func.__name__,
@@ -125,10 +255,7 @@ class CacheInvalidationStore(SQLBaseStore):
# the transaction. However, we want to only get an ID when we want
# to use it, here, so we need to call __enter__ manually, and have
# __exit__ called after the transaction finishes.
- ctx = self._cache_id_gen.get_next()
- stream_id = ctx.__enter__()
- txn.call_on_exception(ctx.__exit__, None, None, None)
- txn.call_after(ctx.__exit__, None, None, None)
+ stream_id = self._cache_id_gen.get_next_txn(txn)
txn.call_after(self.hs.get_notifier().on_new_replication_data)
if keys is not None:
@@ -136,37 +263,18 @@ class CacheInvalidationStore(SQLBaseStore):
self.db.simple_insert_txn(
txn,
- table="cache_invalidation_stream",
+ table="cache_invalidation_stream_by_instance",
values={
"stream_id": stream_id,
+ "instance_name": self._instance_name,
"cache_func": cache_name,
"keys": keys,
"invalidation_ts": self.clock.time_msec(),
},
)
- def get_all_updated_caches(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_caches_txn(txn):
- # We purposefully don't bound by the current token, as we want to
- # send across cache invalidations as quickly as possible. Cache
- # invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
- return txn.fetchall()
-
- return self.db.runInteraction(
- "get_all_updated_caches", get_all_updated_caches_txn
- )
-
- def get_cache_stream_token(self):
+ def get_cache_stream_token(self, instance_name):
if self._cache_id_gen:
- return self._cache_id_gen.get_current_token()
+ return self._cache_id_gen.get_current_token(instance_name)
else:
return 0
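
The switch to `get_next_txn` replaces the manual driving of the ID generator's context manager shown in the removed lines. A simplified side-by-side sketch of the two patterns (`id_gen` and `txn` stand in for the real generator and transaction objects):

```python
def allocate_stream_id_old(id_gen, txn):
    # Old pattern: enter the context manager by hand so the allocated ID stays
    # "in flight" until the surrounding transaction completes or fails.
    ctx = id_gen.get_next()
    stream_id = ctx.__enter__()
    txn.call_on_exception(ctx.__exit__, None, None, None)
    txn.call_after(ctx.__exit__, None, None, None)
    return stream_id


def allocate_stream_id_new(id_gen, txn):
    # New pattern: the generator ties the ID's lifetime to the transaction itself.
    return id_gen.get_next_txn(txn)
```
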
diff --git a/synapse/storage/data_stores/main/censor_events.py b/synapse/storage/data_stores/main/censor_events.py
new file mode 100644
index 0000000000..2d48261724
--- /dev/null
+++ b/synapse/storage/data_stores/main/censor_events.py
@@ -0,0 +1,208 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.internet import defer
+
+from synapse.events.utils import prune_event_dict
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.data_stores.main.events import encode_json
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+
+ def _censor_redactions():
+ return run_as_background_process(
+ "_censor_redactions", self._censor_redactions
+ )
+
+ if self.hs.config.redaction_retention_period is not None:
+ hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+
+ async def _censor_redactions(self):
+ """Censors all redactions older than the configured period that haven't
+ been censored yet.
+
+ By censor we mean updating the event_json table with the redacted event.
+ """
+
+ if self.hs.config.redaction_retention_period is None:
+ return
+
+ if not (
+ await self.db.updates.has_completed_background_update(
+ "redactions_have_censored_ts_idx"
+ )
+ ):
+ # We don't want to run this until the appropriate index has been
+ # created.
+ return
+
+ before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+
+ # We fetch all redactions that:
+ # 1. point to an event we have,
+ # 2. have a received_ts from before the cutoff, and
+ # 3. we haven't yet censored.
+ #
+ # This is limited to 100 events to ensure that we don't try and do too
+ # much at once. We'll get called again so this should eventually catch
+ # up.
+ sql = """
+ SELECT redactions.event_id, redacts FROM redactions
+ LEFT JOIN events AS original_event ON (
+ redacts = original_event.event_id
+ )
+ WHERE NOT have_censored
+ AND redactions.received_ts <= ?
+ ORDER BY redactions.received_ts ASC
+ LIMIT ?
+ """
+
+ rows = await self.db.execute(
+ "_censor_redactions_fetch", None, sql, before_ts, 100
+ )
+
+ updates = []
+
+ for redaction_id, event_id in rows:
+ redaction_event = await self.get_event(redaction_id, allow_none=True)
+ original_event = await self.get_event(
+ event_id, allow_rejected=True, allow_none=True
+ )
+
+ # The SQL above ensures that we have both the redaction and
+ # original event, so if the `get_event` calls return None it
+ # means that the redaction wasn't allowed. Either way we know that
+ # the result won't change so we mark the fact that we've checked.
+ if (
+ redaction_event
+ and original_event
+ and original_event.internal_metadata.is_redacted()
+ ):
+ # Redaction was allowed
+ pruned_json = encode_json(
+ prune_event_dict(
+ original_event.room_version, original_event.get_dict()
+ )
+ )
+ else:
+ # Redaction wasn't allowed
+ pruned_json = None
+
+ updates.append((redaction_id, event_id, pruned_json))
+
+ def _update_censor_txn(txn):
+ for redaction_id, event_id, pruned_json in updates:
+ if pruned_json:
+ self._censor_event_txn(txn, event_id, pruned_json)
+
+ self.db.simple_update_one_txn(
+ txn,
+ table="redactions",
+ keyvalues={"event_id": redaction_id},
+ updatevalues={"have_censored": True},
+ )
+
+ await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+
+ def _censor_event_txn(self, txn, event_id, pruned_json):
+ """Censor an event by replacing its JSON in the event_json table with the
+ provided pruned JSON.
+
+ Args:
+ txn (LoggingTransaction): The database transaction.
+ event_id (str): The ID of the event to censor.
+ pruned_json (str): The pruned JSON
+ """
+ self.db.simple_update_one_txn(
+ txn,
+ table="event_json",
+ keyvalues={"event_id": event_id},
+ updatevalues={"json": pruned_json},
+ )
+
+ @defer.inlineCallbacks
+ def expire_event(self, event_id):
+ """Retrieve and expire an event that has expired, and delete its associated
+ expiry timestamp. If the event can't be retrieved, delete its associated
+ timestamp so we don't try to expire it again in the future.
+
+ Args:
+ event_id (str): The ID of the event to delete.
+ """
+ # Try to retrieve the event's content from the database or the event cache.
+ event = yield self.get_event(event_id)
+
+ def delete_expired_event_txn(txn):
+ # Delete the expiry timestamp associated with this event from the database.
+ self._delete_event_expiry_txn(txn, event_id)
+
+ if not event:
+ # If we can't find the event, log a warning and delete the expiry date
+ # from the database so that we don't try to expire it again in the
+ # future.
+ logger.warning(
+ "Can't expire event %s because we don't have it.", event_id
+ )
+ return
+
+ # Prune the event's dict then convert it to JSON.
+ pruned_json = encode_json(
+ prune_event_dict(event.room_version, event.get_dict())
+ )
+
+ # Update the event_json table to replace the event's JSON with the pruned
+ # JSON.
+ self._censor_event_txn(txn, event.event_id, pruned_json)
+
+ # We need to invalidate the event cache entry for this event because we
+ # changed its content in the database. We can't call
+ # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+ # right type.
+ txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+ # Send that invalidation to replication so that other workers also invalidate
+ # the event cache.
+ self._send_invalidation_to_replication(
+ txn, "_get_event_cache", (event.event_id,)
+ )
+
+ yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+ def _delete_event_expiry_txn(self, txn, event_id):
+ """Delete the expiry timestamp associated with an event ID without deleting the
+ actual event.
+
+ Args:
+ txn (LoggingTransaction): The transaction to use to perform the deletion.
+ event_id (str): The event ID to delete the associated expiry timestamp of.
+ """
+ return self.db.simple_delete_txn(
+ txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+ )
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index e1ccb27142..71f8d43a76 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -21,8 +21,7 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
-from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.storage.database import Database, make_tuple_comparison_clause
from synapse.util.caches.descriptors import Cache
logger = logging.getLogger(__name__)
@@ -303,16 +302,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
# we'll just end up updating the same device row multiple
# times, which is fine.
- if self.database_engine.supports_tuple_comparison:
- where_clause = "(user_id, device_id) > (?, ?)"
- where_args = [last_user_id, last_device_id]
- else:
- # We explicitly do a `user_id >= ? AND (...)` here to ensure
- # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
- # makes it hard for query optimiser to tell that it can use the
- # index on user_id
- where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
- where_args = [last_user_id, last_user_id, last_device_id]
+ where_clause, where_args = make_tuple_comparison_clause(
+ self.database_engine,
+ [("user_id", last_user_id), ("device_id", last_device_id)],
+ )
sql = """
SELECT
@@ -367,7 +360,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
self.client_ip_last_seen = Cache(
- name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+ name="client_ip_last_seen", keylen=4, max_entries=50000
)
super(ClientIpStore, self).__init__(database, db_conn, hs)
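
The removed branch shows exactly what `make_tuple_comparison_clause` abstracts over. A simplified two-column sketch of such a helper (not the real implementation in `synapse.storage.database`):

```python
def tuple_comparison_clause_sketch(supports_tuple_comparison, col1, val1, col2, val2):
    """Return (clause, args) expressing "(col1, col2) > (val1, val2)"."""
    if supports_tuple_comparison:
        # e.g. "(user_id, device_id) > (?, ?)"
        return "(%s, %s) > (?, ?)" % (col1, col2), [val1, val2]
    # Otherwise spell it out so the leading column can still use an index,
    # exactly like the removed "user_id >= ? AND (user_id > ? OR device_id > ?)".
    clause = "%s >= ? AND (%s > ? OR %s > ?)" % (col1, col1, col2)
    return clause, [val1, val1, val2]
```
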
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 0613b49f4a..9a1178fb39 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
+ def get_all_new_device_messages(self, last_pos, current_pos, limit):
+ """
+ Args:
+ last_pos(int):
+ current_pos(int):
+ limit(int):
+ Returns:
+ A deferred list of rows from the device inbox
+ """
+ if last_pos == current_pos:
+ return defer.succeed([])
+
+ def get_all_new_device_messages_txn(txn):
+ # We limit like this as we might have multiple rows per stream_id, and
+ # we want to make sure we always get all entries for any stream_id
+ # we return.
+ upper_pos = min(current_pos, last_pos + limit)
+ sql = (
+ "SELECT max(stream_id), user_id"
+ " FROM device_inbox"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " GROUP BY user_id"
+ )
+ txn.execute(sql, (last_pos, upper_pos))
+ rows = txn.fetchall()
+
+ sql = (
+ "SELECT max(stream_id), destination"
+ " FROM device_federation_outbox"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " GROUP BY destination"
+ )
+ txn.execute(sql, (last_pos, upper_pos))
+ rows.extend(txn)
+
+ # Order by ascending stream ordering
+ rows.sort()
+
+ return rows
+
+ return self.db.runInteraction(
+ "get_all_new_device_messages", get_all_new_device_messages_txn
+ )
+
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
@@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((user_id, device_id, stream_id, message_json))
txn.executemany(sql, rows)
-
- def get_all_new_device_messages(self, last_pos, current_pos, limit):
- """
- Args:
- last_pos(int):
- current_pos(int):
- limit(int):
- Returns:
- A deferred list of rows from the device inbox
- """
- if last_pos == current_pos:
- return defer.succeed([])
-
- def get_all_new_device_messages_txn(txn):
- # We limit like this as we might have multiple rows per stream_id, and
- # we want to make sure we always get all entries for any stream_id
- # we return.
- upper_pos = min(current_pos, last_pos + limit)
- sql = (
- "SELECT max(stream_id), user_id"
- " FROM device_inbox"
- " WHERE ? < stream_id AND stream_id <= ?"
- " GROUP BY user_id"
- )
- txn.execute(sql, (last_pos, upper_pos))
- rows = txn.fetchall()
-
- sql = (
- "SELECT max(stream_id), destination"
- " FROM device_federation_outbox"
- " WHERE ? < stream_id AND stream_id <= ?"
- " GROUP BY destination"
- )
- txn.execute(sql, (last_pos, upper_pos))
- rows.extend(txn)
-
- # Order by ascending stream ordering
- rows.sort()
-
- return rows
-
- return self.db.runInteraction(
- "get_all_new_device_messages", get_all_new_device_messages_txn
- )
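
The `upper_pos` calculation in the relocated `get_all_new_device_messages` bounds the query by stream position rather than by row count, so a stream_id with several rows is never returned half-fetched. A small worked example (the numbers are made up):

```python
last_pos, limit, current_pos = 10, 5, 100
upper_pos = min(current_pos, last_pos + limit)  # -> 15

# The queries then select every row with 10 < stream_id <= 15.  If stream_id 15
# happens to have two rows (e.g. messages for two users written at once), both
# are returned, whereas a plain "LIMIT 5" over the raw rows could return only
# one of them and silently drop the other.
```
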
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 8af5f7de54..fb9f798e29 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Optional, Set, Tuple
from six import iteritems
@@ -31,7 +32,11 @@ from synapse.logging.opentracing import (
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import (
+ Database,
+ LoggingTransaction,
+ make_tuple_comparison_clause,
+)
from synapse.types import Collection, get_verify_key_from_cross_signing_key
from synapse.util.caches.descriptors import (
Cache,
@@ -40,6 +45,7 @@ from synapse.util.caches.descriptors import (
cachedList,
)
from synapse.util.iterutils import batch_iter
+from synapse.util.stringutils import shortstr
logger = logging.getLogger(__name__)
@@ -47,6 +53,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
"drop_device_list_streams_non_unique_indexes"
)
+BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"
+
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
@@ -112,23 +120,13 @@ class DeviceWorkerStore(SQLBaseStore):
if not has_changed:
return now_stream_id, []
- # We retrieve n+1 devices from the list of outbound pokes where n is
- # our outbound device update limit. We then check if the very last
- # device has the same stream_id as the second-to-last device. If so,
- # then we ignore all devices with that stream_id and only send the
- # devices with a lower stream_id.
- #
- # If when culling the list we end up with no devices afterwards, we
- # consider the device update to be too large, and simply skip the
- # stream_id; the rationale being that such a large device list update
- # is likely an error.
updates = yield self.db.runInteraction(
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
- limit + 1,
+ limit,
)
# Return an empty list if there are no updates
@@ -166,14 +164,6 @@ class DeviceWorkerStore(SQLBaseStore):
"device_id": verify_key.version,
}
- # if we have exceeded the limit, we need to exclude any results with the
- # same stream_id as the last row.
- if len(updates) > limit:
- stream_id_cutoff = updates[-1][2]
- now_stream_id = stream_id_cutoff - 1
- else:
- stream_id_cutoff = None
-
# Perform the equivalent of a GROUP BY
#
# Iterate through the updates list and copy non-duplicate
@@ -181,7 +171,6 @@ class DeviceWorkerStore(SQLBaseStore):
# the max stream_id across each set of duplicate entries
#
# maps (user_id, device_id) -> (stream_id, opentracing_context)
- # as long as their stream_id does not match that of the last row
#
# opentracing_context contains the opentracing metadata for the request
# that created the poke
@@ -192,10 +181,6 @@ class DeviceWorkerStore(SQLBaseStore):
query_map = {}
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
- if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
- # Stop processing updates
- break
-
if (
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
@@ -218,17 +203,6 @@ class DeviceWorkerStore(SQLBaseStore):
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)
- # If we didn't find any updates with a stream_id lower than the cutoff, it
- # means that there are more than limit updates all of which have the same
- # steam_id.
-
- # That should only happen if a client is spamming the server with new
- # devices, in which case E2E isn't going to work well anyway. We'll just
- # skip that stream_id and return an empty list, and continue with the next
- # stream_id next time.
- if not query_map and not cross_signing_keys_by_user:
- return stream_id_cutoff, []
-
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
@@ -259,11 +233,11 @@ class DeviceWorkerStore(SQLBaseStore):
# get the list of device updates that need to be sent
sql = """
SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
- WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
+ WHERE destination = ? AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id
LIMIT ?
"""
- txn.execute(sql, (destination, from_stream_id, now_stream_id, False, limit))
+ txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
return list(txn)
@@ -301,7 +275,14 @@ class DeviceWorkerStore(SQLBaseStore):
prev_id = yield self._get_last_device_update_for_remote_user(
destination, user_id, from_stream_id
)
- for device_id, device in iteritems(user_devices):
+
+ # make sure we go through the devices in stream order
+ device_ids = sorted(
+ user_devices.keys(), key=lambda i: query_map[(user_id, i)][0],
+ )
+
+ for device_id in device_ids:
+ device = user_devices[device_id]
stream_id, opentracing_context = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
@@ -361,32 +342,23 @@ class DeviceWorkerStore(SQLBaseStore):
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# We update the device_lists_outbound_last_success with the successfully
- # poked users. We do the join to see which users need to be inserted and
- # which updated.
+ # poked users.
sql = """
- SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL)
+ SELECT user_id, coalesce(max(o.stream_id), 0)
FROM device_lists_outbound_pokes as o
- LEFT JOIN device_lists_outbound_last_success as s
- USING (destination, user_id)
WHERE destination = ? AND o.stream_id <= ?
GROUP BY user_id
"""
txn.execute(sql, (destination, stream_id))
rows = txn.fetchall()
- sql = """
- UPDATE device_lists_outbound_last_success
- SET stream_id = ?
- WHERE destination = ? AND user_id = ?
- """
- txn.executemany(sql, ((row[1], destination, row[0]) for row in rows if row[2]))
-
- sql = """
- INSERT INTO device_lists_outbound_last_success
- (destination, user_id, stream_id) VALUES (?, ?, ?)
- """
- txn.executemany(
- sql, ((destination, row[0], row[1]) for row in rows if not row[2])
+ self.db.simple_upsert_many_txn(
+ txn=txn,
+ table="device_lists_outbound_last_success",
+ key_names=("destination", "user_id"),
+ key_values=((destination, user_id) for user_id, _ in rows),
+ value_names=("stream_id",),
+ value_values=((stream_id,) for _, stream_id in rows),
)
# Delete all sent outbound pokes
@@ -560,8 +532,8 @@ class DeviceWorkerStore(SQLBaseStore):
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- to_check = list(
- self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+ to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
)
if not to_check:
@@ -611,22 +583,33 @@ class DeviceWorkerStore(SQLBaseStore):
else:
return set()
- def get_all_device_list_changes_for_remotes(self, from_key, to_key):
- """Return a list of `(stream_id, user_id, destination)` which is the
- combined list of changes to devices, and which destinations need to be
- poked. `destination` may be None if no destinations need to be poked.
+ async def get_all_device_list_changes_for_remotes(
+ self, from_key: int, to_key: int, limit: int,
+ ) -> List[Tuple[int, str]]:
+ """Return a list of `(stream_id, entity)` which is the combined list of
+ changes to devices and which destinations need to be poked. Entity is
+ either a user ID (starting with '@') or a remote destination.
"""
- # We do a group by here as there can be a large number of duplicate
- # entries, since we throw away device IDs.
+
+ # This query Does The Right Thing where it'll correctly apply the
+ # bounds to the inner queries.
sql = """
- SELECT MAX(stream_id) AS stream_id, user_id, destination
- FROM device_lists_stream
- LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id, destination
+ LIMIT ?
"""
- return self.db.execute(
- "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
+
+ return await self.db.execute(
+ "get_all_device_list_changes_for_remotes",
+ None,
+ sql,
+ from_key,
+ to_key,
+ limit,
)
@cached(max_entries=10000)
@@ -662,21 +645,31 @@ class DeviceWorkerStore(SQLBaseStore):
return results
@defer.inlineCallbacks
- def get_user_ids_requiring_device_list_resync(self, user_ids: Collection[str]):
+ def get_user_ids_requiring_device_list_resync(
+ self, user_ids: Optional[Collection[str]] = None,
+ ) -> Set[str]:
"""Given a list of remote users return the list of users that we
- should resync the device lists for.
+ should resync the device lists for. If None is given instead of a list,
+ return every user that we should resync the device lists for.
Returns:
- Deferred[Set[str]]
+ The IDs of users whose device lists need resync.
"""
-
- rows = yield self.db.simple_select_many_batch(
- table="device_lists_remote_resync",
- column="user_id",
- iterable=user_ids,
- retcols=("user_id",),
- desc="get_user_ids_requiring_device_list_resync",
- )
+ if user_ids:
+ rows = yield self.db.simple_select_many_batch(
+ table="device_lists_remote_resync",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync_with_iterable",
+ )
+ else:
+ rows = yield self.db.simple_select_list(
+ table="device_lists_remote_resync",
+ keyvalues=None,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync",
+ )
return {row["user_id"] for row in rows}
@@ -692,6 +685,25 @@ class DeviceWorkerStore(SQLBaseStore):
desc="make_remote_user_device_cache_as_stale",
)
+ def mark_remote_user_device_list_as_unsubscribed(self, user_id):
+ """Mark that we no longer track device lists for remote user.
+ """
+
+ def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
+ self.db.simple_delete_txn(
+ txn,
+ table="device_lists_remote_extremeties",
+ keyvalues={"user_id": user_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
+ )
+
+ return self.db.runInteraction(
+ "mark_remote_user_device_list_as_unsubscribed",
+ _mark_remote_user_device_list_as_unsubscribed_txn,
+ )
+
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
@@ -728,6 +740,20 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
self._drop_device_list_streams_non_unique_indexes,
)
+ # clear out duplicate device list outbound pokes
+ self.db.updates.register_background_update_handler(
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, self._remove_duplicate_outbound_pokes,
+ )
+
+ # a pair of background updates that were added during the 1.14 release cycle,
+ # but replaced with 58/06dlols_unique_idx.py
+ self.db.updates.register_noop_background_update(
+ "device_lists_outbound_last_success_unique_idx",
+ )
+ self.db.updates.register_noop_background_update(
+ "drop_device_lists_outbound_last_success_non_unique_idx",
+ )
+
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
@@ -742,6 +768,66 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
)
return 1
+ async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
+ # for some reason, we have accumulated duplicate entries in
+ # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
+ # efficient.
+ #
+ # For each duplicate, we delete all the existing rows and put one back.
+
+ KEY_COLS = ["stream_id", "destination", "user_id", "device_id"]
+ last_row = progress.get(
+ "last_row",
+ {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
+ )
+
+ def _txn(txn):
+ clause, args = make_tuple_comparison_clause(
+ self.db.engine, [(x, last_row[x]) for x in KEY_COLS]
+ )
+ sql = """
+ SELECT stream_id, destination, user_id, device_id, MAX(ts) AS ts
+ FROM device_lists_outbound_pokes
+ WHERE %s
+ GROUP BY %s
+ HAVING count(*) > 1
+ ORDER BY %s
+ LIMIT ?
+ """ % (
+ clause, # WHERE
+ ",".join(KEY_COLS), # GROUP BY
+ ",".join(KEY_COLS), # ORDER BY
+ )
+ txn.execute(sql, args + [batch_size])
+ rows = self.db.cursor_to_dict(txn)
+
+ row = None
+ for row in rows:
+ self.db.simple_delete_txn(
+ txn, "device_lists_outbound_pokes", {x: row[x] for x in KEY_COLS},
+ )
+
+ row["sent"] = False
+ self.db.simple_insert_txn(
+ txn, "device_lists_outbound_pokes", row,
+ )
+
+ if row:
+ self.db.updates._background_update_progress_txn(
+ txn, BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, {"last_row": row},
+ )
+
+ return len(rows)
+
+ rows = await self.db.runInteraction(BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES, _txn)
+
+ if not rows:
+ await self.db.updates._end_background_update(
+ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES
+ )
+
+ return rows
+
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
@@ -878,17 +964,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
desc="update_device",
)
- @defer.inlineCallbacks
- def mark_remote_user_device_list_as_unsubscribed(self, user_id):
- """Mark that we no longer track device lists for remote user.
- """
- yield self.db.simple_delete(
- table="device_lists_remote_extremeties",
- keyvalues={"user_id": user_id},
- desc="mark_remote_user_device_list_as_unsubscribed",
- )
- self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
-
def update_remote_device_list_cache_entry(
self, user_id, device_id, content, stream_id
):
@@ -1021,29 +1096,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
- with self._device_list_id_gen.get_next() as stream_id:
+ if not device_ids:
+ return
+
+ with self._device_list_id_gen.get_next_mult(len(device_ids)) as stream_ids:
+ yield self.db.runInteraction(
+ "add_device_change_to_stream",
+ self._add_device_change_to_stream_txn,
+ user_id,
+ device_ids,
+ stream_ids,
+ )
+
+ if not hosts:
+ return stream_ids[-1]
+
+ context = get_active_span_text_map()
+ with self._device_list_id_gen.get_next_mult(
+ len(hosts) * len(device_ids)
+ ) as stream_ids:
yield self.db.runInteraction(
- "add_device_change_to_streams",
- self._add_device_change_txn,
+ "add_device_outbound_poke_to_stream",
+ self._add_device_outbound_poke_to_stream_txn,
user_id,
device_ids,
hosts,
- stream_id,
+ stream_ids,
+ context,
)
- return stream_id
- def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
- now = self._clock.time_msec()
+ return stream_ids[-1]
+ def _add_device_change_to_stream_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+ stream_ids: List[int],
+ ):
txn.call_after(
- self._device_list_stream_cache.entity_has_changed, user_id, stream_id
+ self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
)
- for host in hosts:
- txn.call_after(
- self._device_list_federation_stream_cache.entity_has_changed,
- host,
- stream_id,
- )
+
+ min_stream_id = stream_ids[0]
# Delete older entries in the table, as we really only care about
# when the latest change happened.
@@ -1052,7 +1147,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
- [(user_id, device_id, stream_id) for device_id in device_ids],
+ [(user_id, device_id, min_stream_id) for device_id in device_ids],
)
self.db.simple_insert_many_txn(
@@ -1060,11 +1155,22 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
table="device_lists_stream",
values=[
{"stream_id": stream_id, "user_id": user_id, "device_id": device_id}
- for device_id in device_ids
+ for stream_id, device_id in zip(stream_ids, device_ids)
],
)
- context = get_active_span_text_map()
+ def _add_device_outbound_poke_to_stream_txn(
+ self, txn, user_id, device_ids, hosts, stream_ids, context,
+ ):
+ for host in hosts:
+ txn.call_after(
+ self._device_list_federation_stream_cache.entity_has_changed,
+ host,
+ stream_ids[-1],
+ )
+
+ now = self._clock.time_msec()
+ next_stream_id = iter(stream_ids)
self.db.simple_insert_many_txn(
txn,
@@ -1072,7 +1178,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
{
"destination": destination,
- "stream_id": stream_id,
+ "stream_id": next(next_stream_id),
"user_id": user_id,
"device_id": device_id,
"sent": False,
@@ -1086,18 +1192,47 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
],
)
- def _prune_old_outbound_device_pokes(self):
+ def _prune_old_outbound_device_pokes(self, prune_age=24 * 60 * 60 * 1000):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
- that we don't fill up due to dead servers. We keep one entry per
- (destination, user_id) tuple to ensure that the prev_ids remain correct
- if the server does come back.
+ that we don't fill up due to dead servers.
+
+ Normally, we try to send device updates as a delta since a previous known point:
+ this is done by setting the prev_id in the m.device_list_update EDU. However,
+ for that to work, we have to have a complete record of each change to
+ each device, which can add up to quite a lot of data.
+
+ An alternative mechanism is that, if the remote server sees that it has missed
+ an entry in the stream_id sequence for a given user, it will request a full
+ list of that user's devices. Hence, we can reduce the amount of data we have to
+ store (and transmit in some future transaction), by clearing almost everything
+ for a given destination out of the database, and having the remote server
+ resync.
+
+ All we need to do is make sure we keep at least one row for each
+ (user, destination) pair, to remind us to send an m.device_list_update EDU for
+ that user when the destination comes back. It doesn't matter which device
+ we keep.
"""
- yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
+ yesterday = self._clock.time_msec() - prune_age
def _prune_txn(txn):
+ # look for (user, destination) pairs which have an update older than
+ # the cutoff.
+ #
+ # For each pair, we also need to know the most recent stream_id, and
+ # an arbitrary device_id at that stream_id.
select_sql = """
- SELECT destination, user_id, max(stream_id) as stream_id
- FROM device_lists_outbound_pokes
+ SELECT
+ dlop1.destination,
+ dlop1.user_id,
+ MAX(dlop1.stream_id) AS stream_id,
+ (SELECT MIN(dlop2.device_id) AS device_id FROM
+ device_lists_outbound_pokes dlop2
+ WHERE dlop2.destination = dlop1.destination AND
+ dlop2.user_id=dlop1.user_id AND
+ dlop2.stream_id=MAX(dlop1.stream_id)
+ )
+ FROM device_lists_outbound_pokes dlop1
GROUP BY destination, user_id
HAVING min(ts) < ? AND count(*) > 1
"""
@@ -1108,14 +1243,29 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not rows:
return
+ logger.info(
+ "Pruning old outbound device list updates for %i users/destinations: %s",
+ len(rows),
+ shortstr((row[0], row[1]) for row in rows),
+ )
+
+ # we want to keep the update with the highest stream_id for each user.
+ #
+ # there might be more than one update (with different device_ids) with the
+ # same stream_id, so we also delete all but one row with the max stream_id.
delete_sql = """
DELETE FROM device_lists_outbound_pokes
- WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
+ WHERE destination = ? AND user_id = ? AND (
+ stream_id < ? OR
+ (stream_id = ? AND device_id != ?)
+ )
"""
-
- txn.executemany(
- delete_sql, ((yesterday, row[0], row[1], row[2]) for row in rows)
- )
+ count = 0
+ for (destination, user_id, stream_id, device_id) in rows:
+ txn.execute(
+ delete_sql, (destination, user_id, stream_id, stream_id, device_id)
+ )
+ count += txn.rowcount
# Since we've deleted unsent deltas, we need to remove the entry
# of last successful sent so that the prev_ids are correctly set.
@@ -1125,7 +1275,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
"""
txn.executemany(sql, ((row[0], row[1]) for row in rows))
- logger.info("Pruned %d device list outbound pokes", txn.rowcount)
+ logger.info("Pruned %d device list outbound pokes", count)
return run_as_background_process(
"prune_old_outbound_device_pokes",
diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py
index c9e7de7d12..e1d1bc3e05 100644
--- a/synapse/storage/data_stores/main/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -14,6 +14,7 @@
# limitations under the License.
from collections import namedtuple
+from typing import Optional
from twisted.internet import defer
@@ -159,10 +160,29 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
- def update_aliases_for_room(self, old_room_id, new_room_id, creator):
+ def update_aliases_for_room(
+ self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
+ ):
+ """Repoint all of the aliases for a given room, to a different room.
+
+ Args:
+ old_room_id: The room whose aliases should be repointed.
+ new_room_id: The room to repoint the aliases at.
+ creator: The user to record as the creator of the new mapping.
+ If None, the creator will be left unchanged.
+ """
+
def _update_aliases_for_room_txn(txn):
- sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?"
- txn.execute(sql, (new_room_id, creator, old_room_id))
+ update_creator_sql = ""
+ sql_params = (new_room_id, old_room_id)
+ if creator:
+ update_creator_sql = ", creator = ?"
+ sql_params = (new_room_id, creator, old_room_id)
+
+ sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % (
+ update_creator_sql,
+ )
+ txn.execute(sql, sql_params)
self._invalidate_cache_and_stream(
txn, self.get_aliases_for_room, (old_room_id,)
)
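As an aside, a minimal, self-contained sketch (not part of this diff; table and values are made up) of the optional-column-update pattern used above, where `creator` is only touched when a new value is supplied:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE room_aliases (room_alias TEXT, room_id TEXT, creator TEXT)")
    conn.execute(
        "INSERT INTO room_aliases VALUES ('#old:example.com', '!old:example.com', '@alice:example.com')"
    )

    def update_aliases_for_room(old_room_id, new_room_id, creator=None):
        update_creator_sql = ""
        sql_params = (new_room_id, old_room_id)
        if creator:
            update_creator_sql = ", creator = ?"
            sql_params = (new_room_id, creator, old_room_id)

        sql = "UPDATE room_aliases SET room_id = ? %s WHERE room_id = ?" % (update_creator_sql,)
        conn.execute(sql, sql_params)

    update_aliases_for_room("!old:example.com", "!new:example.com")  # creator left unchanged
    print(conn.execute("SELECT room_id, creator FROM room_aliases").fetchone())
    # ('!new:example.com', '@alice:example.com')

    update_aliases_for_room("!new:example.com", "!newer:example.com", creator="@bob:example.com")
    print(conn.execute("SELECT room_id, creator FROM room_aliases").fetchone())
    # ('!newer:example.com', '@bob:example.com')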
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 84594cf0a9..23f4570c4b 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -146,7 +146,8 @@ class EndToEndRoomKeyStore(SQLBaseStore):
room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
- "is_verified": row["is_verified"],
+ # is_verified must be returned to the client as a boolean
+ "is_verified": bool(row["is_verified"]),
"session_data": json.loads(row["session_data"]),
}
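As an aside, a tiny sketch (not part of this diff; table and values are made up) of why the bool() cast above matters: SQLite stores booleans as integers, so without the cast the client would see 0/1 rather than false/true.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE e2e_room_keys (session_id TEXT, is_verified BOOLEAN)")
    conn.execute("INSERT INTO e2e_room_keys VALUES ('session1', 1)")

    (is_verified,) = conn.execute("SELECT is_verified FROM e2e_room_keys").fetchone()
    print(is_verified, bool(is_verified))  # 1 True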
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 001a53f9b4..20698bfd16 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -25,7 +25,9 @@ from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import make_in_list_sql_clause
from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -268,53 +270,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys
)
- def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
- """Returns a user's cross-signing key.
-
- Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- user_id (str): the user whose key is being requested
- key_type (str): the type of key that is being requested: either 'master'
- for a master key, 'self_signing' for a self-signing key, or
- 'user_signing' for a user-signing key
- from_user_id (str): if specified, signatures made by this user on
- the key will be included in the result
-
- Returns:
- dict of the key data or None if not found
- """
- sql = (
- "SELECT keydata "
- " FROM e2e_cross_signing_keys "
- " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
- )
- txn.execute(sql, (user_id, key_type))
- row = txn.fetchone()
- if not row:
- return None
- key = json.loads(row[0])
-
- device_id = None
- for k in key["keys"].values():
- device_id = k
-
- if from_user_id is not None:
- sql = (
- "SELECT key_id, signature "
- " FROM e2e_cross_signing_signatures "
- " WHERE user_id = ? "
- " AND target_user_id = ? "
- " AND target_device_id = ? "
- )
- txn.execute(sql, (from_user_id, user_id, device_id))
- row = txn.fetchone()
- if row:
- key.setdefault("signatures", {}).setdefault(from_user_id, {})[
- row[0]
- ] = row[1]
-
- return key
-
+ @defer.inlineCallbacks
def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
"""Returns a user's cross-signing key.
@@ -329,13 +285,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Returns:
dict of the key data or None if not found
"""
- return self.db.runInteraction(
- "get_e2e_cross_signing_key",
- self._get_e2e_cross_signing_key_txn,
- user_id,
- key_type,
- from_user_id,
- )
+ res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
+ user_keys = res.get(user_id)
+ if not user_keys:
+ return None
+ return user_keys.get(key_type)
@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id):
@@ -391,26 +345,24 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""
result = {}
- batch_size = 100
- chunks = [
- user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
- ]
- for user_chunk in chunks:
- sql = """
+ for user_chunk in batch_iter(user_ids, 100):
+ clause, params = make_in_list_sql_clause(
+ txn.database_engine, "k.user_id", user_chunk
+ )
+ sql = (
+ """
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys
GROUP BY user_id, keytype) s
USING (user_id, stream_id, keytype)
- WHERE k.user_id IN (%s)
- """ % (
- ",".join("?" for u in user_chunk),
+ WHERE
+ """
+ + clause
)
- query_params = []
- query_params.extend(user_chunk)
- txn.execute(sql, query_params)
+ txn.execute(sql, params)
rows = self.db.cursor_to_dict(txn)
for row in rows:
@@ -453,15 +405,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
device_id = k
devices[(user_id, device_id)] = key_type
- device_list = list(devices)
-
- # split into batches
- batch_size = 100
- chunks = [
- device_list[i : i + batch_size]
- for i in range(0, len(device_list), batch_size)
- ]
- for user_chunk in chunks:
+ for batch in batch_iter(devices.keys(), size=100):
sql = """
SELECT target_user_id, target_device_id, key_id, signature
FROM e2e_cross_signing_signatures
@@ -469,11 +413,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
AND (%s)
""" % (
" OR ".join(
- "(target_user_id = ? AND target_device_id = ?)" for d in devices
+ "(target_user_id = ? AND target_device_id = ?)" for _ in batch
)
)
query_params = [from_user_id]
- for item in devices:
+ for item in batch:
# item is a (user_id, device_id) tuple
query_params.extend(item)
@@ -537,7 +481,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return result
- def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+ def get_all_user_signature_changes_for_remotes(self, from_key, to_key, limit):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
device with their user-signing key, which is not published to other
@@ -552,13 +496,19 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
"""
sql = """
- SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+ SELECT stream_id, from_user_id AS user_id
FROM user_signature_stream
WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id
+ ORDER BY stream_id ASC
+ LIMIT ?
"""
return self.db.execute(
- "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+ "get_all_user_signature_changes_for_remotes",
+ None,
+ sql,
+ from_key,
+ to_key,
+ limit,
)
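As an aside, a minimal reimplementation sketch that mirrors how synapse.util.iterutils.batch_iter (as imported above) behaves: it yields tuples of at most `size` items from any iterable. The real helper may differ in detail; the example data is made up.

    from itertools import islice
    from typing import Iterable, Iterator, Tuple, TypeVar

    T = TypeVar("T")

    def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
        """Yield successive batches of up to `size` items from `iterable`."""
        sourceiter = iter(iterable)
        return iter(lambda: tuple(islice(sourceiter, size)), ())

    user_ids = ["@user%d:example.com" % i for i in range(250)]
    print([len(chunk) for chunk in batch_iter(user_ids, 100)])  # [100, 100, 50]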
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 62d4e9f599..24ce8c4330 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -173,19 +173,28 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
for event_id in initial_events
}
+ # The sorted list of events whose auth chains we should walk.
+ search = [] # type: List[Tuple[int, str]]
+
# We need to get the depth of the initial events for sorting purposes.
sql = """
SELECT depth, event_id FROM events
WHERE %s
- ORDER BY depth ASC
"""
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "event_id", initial_events
- )
- txn.execute(sql % (clause,), args)
+ # the list can be huge, so let's avoid looking them all up in one massive
+ # query.
+ for batch in batch_iter(initial_events, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
- # The sorted list of events whose auth chains we should walk.
- search = txn.fetchall() # type: List[Tuple[int, str]]
+ # I think building a temporary list with fetchall is more efficient than
+ # just `search.extend(txn)`, but this is unconfirmed
+ search.extend(txn.fetchall())
+
+ # sort by depth
+ search.sort()
# Map from event to its auth events
event_to_auth_events = {} # type: Dict[str, Set[str]]
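As an aside, a minimal, self-contained sketch (not part of this diff; table, data and batch size are made up, and hand-built placeholders stand in for make_in_list_sql_clause) of the pattern in the batched lookup a few lines above: query a large set of IDs in chunks, then sort by depth in Python instead of relying on ORDER BY in each per-batch query.

    import sqlite3
    from itertools import islice

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE events (event_id TEXT PRIMARY KEY, depth INTEGER)")
    conn.executemany(
        "INSERT INTO events VALUES (?, ?)",
        [("$event%d" % i, i % 7) for i in range(1200)],
    )

    initial_events = ["$event%d" % i for i in range(1200)]
    search = []  # list of (depth, event_id) tuples

    it = iter(initial_events)
    for batch in iter(lambda: list(islice(it, 500)), []):
        clause = "event_id IN (%s)" % (",".join("?" for _ in batch),)
        cur = conn.execute("SELECT depth, event_id FROM events WHERE " + clause, batch)
        search.extend(cur.fetchall())

    search.sort()  # sort by depth in Python, once, over all batches
    print(len(search), search[0])  # 1200 (0, '$event0')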
@@ -631,89 +640,6 @@ class EventFederationStore(EventFederationWorkerStore):
self._delete_old_forward_extrem_cache, 60 * 60 * 1000
)
- def _update_min_depth_for_room_txn(self, txn, room_id, depth):
- min_depth = self._get_min_depth_interaction(txn, room_id)
-
- if min_depth is not None and depth >= min_depth:
- return
-
- self.db.simple_upsert_txn(
- txn,
- table="room_depth",
- keyvalues={"room_id": room_id},
- values={"min_depth": depth},
- )
-
- def _handle_mult_prev_events(self, txn, events):
- """
- For the given event, update the event edges table and forward and
- backward extremities tables.
- """
- self.db.simple_insert_many_txn(
- txn,
- table="event_edges",
- values=[
- {
- "event_id": ev.event_id,
- "prev_event_id": e_id,
- "room_id": ev.room_id,
- "is_state": False,
- }
- for ev in events
- for e_id in ev.prev_event_ids()
- ],
- )
-
- self._update_backward_extremeties(txn, events)
-
- def _update_backward_extremeties(self, txn, events):
- """Updates the event_backward_extremities tables based on the new/updated
- events being persisted.
-
- This is called for new events *and* for events that were outliers, but
- are now being persisted as non-outliers.
-
- Forward extremities are handled when we first start persisting the events.
- """
- events_by_room = {}
- for ev in events:
- events_by_room.setdefault(ev.room_id, []).append(ev)
-
- query = (
- "INSERT INTO event_backward_extremities (event_id, room_id)"
- " SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- " )"
- " AND NOT EXISTS ("
- " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
- " AND outlier = ?"
- " )"
- )
-
- txn.executemany(
- query,
- [
- (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
- for ev in events
- for e_id in ev.prev_event_ids()
- if not ev.internal_metadata.is_outlier()
- ],
- )
-
- query = (
- "DELETE FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
- )
- txn.executemany(
- query,
- [
- (ev.event_id, ev.room_id)
- for ev in events
- if not ev.internal_metadata.is_outlier()
- ],
- )
-
def _delete_old_forward_extrem_cache(self):
def _delete_old_forward_extrem_cache_txn(txn):
# Delete entries older than a month, while making sure we don't delete
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 8eed590929..0321274de2 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -652,69 +652,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
self._start_rotate_notifs, 30 * 60 * 1000
)
- def _set_push_actions_for_event_and_users_txn(
- self, txn, events_and_contexts, all_events_and_contexts
- ):
- """Handles moving push actions from staging table to main
- event_push_actions table for all events in `events_and_contexts`.
-
- Also ensures that all events in `all_events_and_contexts` are removed
- from the push action staging area.
-
- Args:
- events_and_contexts (list[(EventBase, EventContext)]): events
- we are persisting
- all_events_and_contexts (list[(EventBase, EventContext)]): all
- events that we were going to persist. This includes events
- we've already persisted, etc, that wouldn't appear in
- events_and_context.
- """
-
- sql = """
- INSERT INTO event_push_actions (
- room_id, event_id, user_id, actions, stream_ordering,
- topological_ordering, notif, highlight
- )
- SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
- FROM event_push_actions_staging
- WHERE event_id = ?
- """
-
- if events_and_contexts:
- txn.executemany(
- sql,
- (
- (
- event.room_id,
- event.internal_metadata.stream_ordering,
- event.depth,
- event.event_id,
- )
- for event, _ in events_and_contexts
- ),
- )
-
- for event, _ in events_and_contexts:
- user_ids = self.db.simple_select_onecol_txn(
- txn,
- table="event_push_actions_staging",
- keyvalues={"event_id": event.event_id},
- retcol="user_id",
- )
-
- for uid in user_ids:
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (event.room_id, uid),
- )
-
- # Now we delete the staging area for *all* events that were being
- # persisted.
- txn.executemany(
- "DELETE FROM event_push_actions_staging WHERE event_id = ?",
- ((event.event_id,) for event, _ in all_events_and_contexts),
- )
-
@defer.inlineCallbacks
def get_push_actions_for_user(
self, user_id, before=None, limit=50, only_highlight=False
@@ -763,17 +700,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
return result[0] or 0
- def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
- # Sad that we have to blow away the cache for the whole room here
- txn.call_after(
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
- (room_id,),
- )
- txn.execute(
- "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
- (room_id, event_id),
- )
-
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index d593ef47b8..a6572571b4 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,39 +17,44 @@
import itertools
import logging
-from collections import Counter as c_counter, OrderedDict, namedtuple
+from collections import OrderedDict, namedtuple
from functools import wraps
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
-from six import iteritems, text_type
+from six import integer_types, iteritems, text_type
from six.moves import range
+import attr
from canonicaljson import json
from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RelationTypes,
+)
from synapse.api.room_versions import RoomVersions
+from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
-from synapse.events.utils import prune_event_dict
from synapse.logging.utils import log_function
-from synapse.metrics import BucketCollector
-from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import make_in_list_sql_clause
-from synapse.storage.data_stores.main.event_federation import EventFederationStore
-from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
-from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.data_stores.main.search import SearchEntry
from synapse.storage.database import Database, LoggingTransaction
-from synapse.storage.persist_events import DeltaState
-from synapse.types import RoomStreamToken, StateMap, get_domain_from_id
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import StateMap, get_domain_from_id
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.iterutils import batch_iter
+if TYPE_CHECKING:
+ from synapse.storage.data_stores.main import DataStore
+ from synapse.server import HomeServer
+
+
logger = logging.getLogger(__name__)
persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
@@ -94,58 +99,49 @@ def _retry_on_integrity_error(func):
return f
-# inherits from EventFederationStore so that we can call _update_backward_extremities
-# and _handle_mult_prev_events (though arguably those could both be moved in here)
-class EventsStore(
- StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
-):
- def __init__(self, database: Database, db_conn, hs):
- super(EventsStore, self).__init__(database, db_conn, hs)
+@attr.s(slots=True)
+class DeltaState:
+ """Deltas to use to update the `current_state_events` table.
- # Collect metrics on the number of forward extremities that exist.
- # Counter of number of extremities to count
- self._current_forward_extremities_amount = c_counter()
+ Attributes:
+ to_delete: List of type/state_keys to delete from current state
+ to_insert: Map of state to upsert into current state
+ no_longer_in_room: The server is no longer in the room, so the room
+ should e.g. be removed from the `current_state_events` table.
+ """
- BucketCollector(
- "synapse_forward_extremities",
- lambda: self._current_forward_extremities_amount,
- buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
- )
+ to_delete = attr.ib(type=List[Tuple[str, str]])
+ to_insert = attr.ib(type=StateMap[str])
+ no_longer_in_room = attr.ib(type=bool, default=False)
- # Read the extrems every 60 minutes
- def read_forward_extremities():
- # run as a background process to make sure that the database transactions
- # have a logcontext to report to
- return run_as_background_process(
- "read_forward_extremities", self._read_forward_extremities
- )
- hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+class PersistEventsStore:
+ """Contains all the functions for writing events to the database.
- def _censor_redactions():
- return run_as_background_process(
- "_censor_redactions", self._censor_redactions
- )
+ Should only be instantiated on one process (when using a worker mode setup).
+
+ Note: This is not part of the `DataStore` mixin.
+ """
- if self.hs.config.redaction_retention_period is not None:
- hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+ def __init__(self, hs: "HomeServer", db: Database, main_data_store: "DataStore"):
+ self.hs = hs
+ self.db = db
+ self.store = main_data_store
+ self.database_engine = db.engine
+ self._clock = hs.get_clock()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
- @defer.inlineCallbacks
- def _read_forward_extremities(self):
- def fetch(txn):
- txn.execute(
- """
- select count(*) c from event_forward_extremities
- group by room_id
- """
- )
- return txn.fetchall()
+ # Ideally we'd move these ID gens here, but unfortunately some other ID
+ # generators are chained off them, so doing so is a bit of a PITA.
+ self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator
+ self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator
- res = yield self.db.runInteraction("read_forward_extremities", fetch)
- self._current_forward_extremities_amount = c_counter([x[0] for x in res])
+ # This should only exist on instances that are configured to write
+ assert (
+ hs.config.worker.writers.events == hs.get_instance_name()
+ ), "Can only instantiate EventsStore on master"
@_retry_on_integrity_error
@defer.inlineCallbacks
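As an aside, a minimal sketch of the attrs pattern used for DeltaState above (the state values are made up, and plain typing aliases stand in for Synapse's StateMap, which maps (event type, state key) to an event ID):

    import attr
    from typing import Dict, List, Tuple

    @attr.s(slots=True)
    class DeltaState:
        to_delete = attr.ib(type=List[Tuple[str, str]])
        to_insert = attr.ib(type=Dict[Tuple[str, str], str])
        no_longer_in_room = attr.ib(type=bool, default=False)

    delta = DeltaState(
        to_delete=[("m.room.member", "@alice:example.com")],
        to_insert={("m.room.member", "@bob:example.com"): "$join_event_id"},
    )
    print(delta.no_longer_in_room)  # False -- the default applies when not given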
@@ -237,10 +233,10 @@ class EventsStore(
event_counter.labels(event.type, origin_type, origin_entity).inc()
for room_id, new_state in iteritems(current_state_for_room):
- self.get_current_state_ids.prefill((room_id,), new_state)
+ self.store.get_current_state_ids.prefill((room_id,), new_state)
for room_id, latest_event_ids in iteritems(new_forward_extremeties):
- self.get_latest_event_ids_in_room.prefill(
+ self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
)
@@ -586,7 +582,7 @@ class EventsStore(
)
txn.call_after(
- self._curr_state_delta_stream_cache.entity_has_changed,
+ self.store._curr_state_delta_stream_cache.entity_has_changed,
room_id,
stream_id,
)
@@ -606,10 +602,13 @@ class EventsStore(
for member in members_changed:
txn.call_after(
- self.get_rooms_for_user_with_stream_ordering.invalidate, (member,)
+ self.store.get_rooms_for_user_with_stream_ordering.invalidate,
+ (member,),
)
- self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+ self.store._invalidate_state_caches_and_stream(
+ txn, room_id, members_changed
+ )
def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
"""Update the room version in the database based off current state
@@ -647,7 +646,9 @@ class EventsStore(
self.db.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
- txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
+ txn.call_after(
+ self.store.get_latest_event_ids_in_room.invalidate, (room_id,)
+ )
self.db.simple_insert_many_txn(
txn,
@@ -713,10 +714,10 @@ class EventsStore(
depth_updates = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
- txn.call_after(self._invalidate_get_event_cache, event.event_id)
+ txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
if not backfilled:
txn.call_after(
- self._events_stream_cache.entity_has_changed,
+ self.store._events_stream_cache.entity_has_changed,
event.room_id,
event.internal_metadata.stream_ordering,
)
@@ -1088,13 +1089,15 @@ class EventsStore(
def prefill():
for cache_entry in to_prefill:
- self._get_event_cache.prefill((cache_entry[0].event_id,), cache_entry)
+ self.store._get_event_cache.prefill(
+ (cache_entry[0].event_id,), cache_entry
+ )
txn.call_after(prefill)
def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event
- txn.call_after(self._invalidate_get_event_cache, event.redacts)
+ txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
self.db.simple_insert_txn(
txn,
@@ -1106,897 +1109,512 @@ class EventsStore(
},
)
- async def _censor_redactions(self):
- """Censors all redactions older than the configured period that haven't
- been censored yet.
-
- By censor we mean update the event_json table with the redacted event.
- """
-
- if self.hs.config.redaction_retention_period is None:
- return
-
- if not (
- await self.db.updates.has_completed_background_update(
- "redactions_have_censored_ts_idx"
- )
- ):
- # We don't want to run this until the appropriate index has been
- # created.
- return
-
- before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+ def insert_labels_for_event_txn(
+ self, txn, event_id, labels, room_id, topological_ordering
+ ):
+ """Store the mapping between an event's ID and its labels, with one row per
+ (event_id, label) tuple.
- # We fetch all redactions that:
- # 1. point to an event we have,
- # 2. has a received_ts from before the cut off, and
- # 3. we haven't yet censored.
- #
- # This is limited to 100 events to ensure that we don't try and do too
- # much at once. We'll get called again so this should eventually catch
- # up.
- sql = """
- SELECT redactions.event_id, redacts FROM redactions
- LEFT JOIN events AS original_event ON (
- redacts = original_event.event_id
- )
- WHERE NOT have_censored
- AND redactions.received_ts <= ?
- ORDER BY redactions.received_ts ASC
- LIMIT ?
+ Args:
+ txn (LoggingTransaction): The transaction to execute.
+ event_id (str): The event's ID.
+ labels (list[str]): A list of text labels.
+ room_id (str): The ID of the room the event was sent to.
+ topological_ordering (int): The position of the event in the room's topology.
"""
-
- rows = await self.db.execute(
- "_censor_redactions_fetch", None, sql, before_ts, 100
+ return self.db.simple_insert_many_txn(
+ txn=txn,
+ table="event_labels",
+ values=[
+ {
+ "event_id": event_id,
+ "label": label,
+ "room_id": room_id,
+ "topological_ordering": topological_ordering,
+ }
+ for label in labels
+ ],
)
- updates = []
-
- for redaction_id, event_id in rows:
- redaction_event = await self.get_event(redaction_id, allow_none=True)
- original_event = await self.get_event(
- event_id, allow_rejected=True, allow_none=True
- )
-
- # The SQL above ensures that we have both the redaction and
- # original event, so if the `get_event` calls return None it
- # means that the redaction wasn't allowed. Either way we know that
- # the result won't change so we mark the fact that we've checked.
- if (
- redaction_event
- and original_event
- and original_event.internal_metadata.is_redacted()
- ):
- # Redaction was allowed
- pruned_json = encode_json(
- prune_event_dict(
- original_event.room_version, original_event.get_dict()
- )
- )
- else:
- # Redaction wasn't allowed
- pruned_json = None
-
- updates.append((redaction_id, event_id, pruned_json))
-
- def _update_censor_txn(txn):
- for redaction_id, event_id, pruned_json in updates:
- if pruned_json:
- self._censor_event_txn(txn, event_id, pruned_json)
-
- self.db.simple_update_one_txn(
- txn,
- table="redactions",
- keyvalues={"event_id": redaction_id},
- updatevalues={"have_censored": True},
- )
-
- await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
-
- def _censor_event_txn(self, txn, event_id, pruned_json):
- """Censor an event by replacing its JSON in the event_json table with the
- provided pruned JSON.
+ def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+ """Save the expiry timestamp associated with a given event ID.
Args:
- txn (LoggingTransaction): The database transaction.
- event_id (str): The ID of the event to censor.
- pruned_json (str): The pruned JSON
+ txn (LoggingTransaction): The database transaction to use.
+ event_id (str): The event ID the expiry timestamp is associated with.
+ expiry_ts (int): The timestamp at which to expire (delete) the event.
"""
- self.db.simple_update_one_txn(
- txn,
- table="event_json",
- keyvalues={"event_id": event_id},
- updatevalues={"json": pruned_json},
+ return self.db.simple_insert_txn(
+ txn=txn,
+ table="event_expiry",
+ values={"event_id": event_id, "expiry_ts": expiry_ts},
)
- @defer.inlineCallbacks
- def count_daily_messages(self):
- """
- Returns an estimate of the number of messages sent in the last day.
-
- If it has been significantly less or more than one day since the last
- call to this function, it will return None.
+ def _store_event_reference_hashes_txn(self, txn, events):
+ """Store a hash for a PDU
+ Args:
+ txn (cursor):
+ events (list): list of Events.
"""
- def _count_messages(txn):
- sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
- WHERE type = 'm.room.message'
- AND stream_ordering > ?
- """
- txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
- return count
-
- ret = yield self.db.runInteraction("count_messages", _count_messages)
- return ret
-
- @defer.inlineCallbacks
- def count_daily_sent_messages(self):
- def _count_messages(txn):
- # This is good enough as if you have silly characters in your own
- # hostname then thats your own fault.
- like_clause = "%:" + self.hs.hostname
-
- sql = """
- SELECT COALESCE(COUNT(*), 0) FROM events
- WHERE type = 'm.room.message'
- AND sender LIKE ?
- AND stream_ordering > ?
- """
-
- txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- (count,) = txn.fetchone()
- return count
-
- ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
- return ret
-
- @defer.inlineCallbacks
- def count_daily_active_rooms(self):
- def _count(txn):
- sql = """
- SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
- WHERE type = 'm.room.message'
- AND stream_ordering > ?
- """
- txn.execute(sql, (self.stream_ordering_day_ago,))
- (count,) = txn.fetchone()
- return count
-
- ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
- return ret
-
- def get_current_backfill_token(self):
- """The current minimum token that backfilled events have reached"""
- return -self._backfill_id_gen.get_current_token()
-
- def get_current_events_token(self):
- """The current maximum token that events have reached"""
- return self._stream_id_gen.get_current_token()
-
- def get_all_new_forward_event_rows(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_new_forward_event_rows(txn):
- sql = (
- "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? < stream_ordering AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- new_event_updates = txn.fetchall()
-
- if len(new_event_updates) == limit:
- upper_bound = new_event_updates[-1][0]
- else:
- upper_bound = current_id
-
- sql = (
- "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? < event_stream_ordering"
- " AND event_stream_ordering <= ?"
- " ORDER BY event_stream_ordering DESC"
+ vals = []
+ for event in events:
+ ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
+ vals.append(
+ {
+ "event_id": event.event_id,
+ "algorithm": ref_alg,
+ "hash": memoryview(ref_hash_bytes),
+ }
)
- txn.execute(sql, (last_id, upper_bound))
- new_event_updates.extend(txn)
- return new_event_updates
+ self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
- return self.db.runInteraction(
- "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+ def _store_room_members_txn(self, txn, events, backfilled):
+ """Store a room member in the database.
+ """
+ self.db.simple_insert_many_txn(
+ txn,
+ table="room_memberships",
+ values=[
+ {
+ "event_id": event.event_id,
+ "user_id": event.state_key,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ "display_name": event.content.get("displayname", None),
+ "avatar_url": event.content.get("avatar_url", None),
+ }
+ for event in events
+ ],
)
- def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_new_backfill_event_rows(txn):
- sql = (
- "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? > stream_ordering AND stream_ordering >= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
+ for event in events:
+ txn.call_after(
+ self.store._membership_stream_cache.entity_has_changed,
+ event.state_key,
+ event.internal_metadata.stream_ordering,
)
- txn.execute(sql, (-last_id, -current_id, limit))
- new_event_updates = txn.fetchall()
-
- if len(new_event_updates) == limit:
- upper_bound = new_event_updates[-1][0]
- else:
- upper_bound = current_id
-
- sql = (
- "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts, relates_to_id"
- " FROM events AS e"
- " INNER JOIN ex_outlier_stream USING (event_id)"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " LEFT JOIN event_relations USING (event_id)"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
+ txn.call_after(
+ self.store.get_invited_rooms_for_local_user.invalidate,
+ (event.state_key,),
)
- txn.execute(sql, (-last_id, -upper_bound))
- new_event_updates.extend(txn.fetchall())
-
- return new_event_updates
- return self.db.runInteraction(
- "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
- )
-
- @cached(num_args=5, max_entries=10)
- def get_all_new_events(
- self,
- last_backfill_id,
- last_forward_id,
- current_backfill_id,
- current_forward_id,
- limit,
- ):
- """Get all the new events that have arrived at the server either as
- new events or as backfilled events"""
- have_backfill_events = last_backfill_id != current_backfill_id
- have_forward_events = last_forward_id != current_forward_id
-
- if not have_backfill_events and not have_forward_events:
- return defer.succeed(AllNewEventsResult([], [], [], [], []))
-
- def get_all_new_events_txn(txn):
- sql = (
- "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? < stream_ordering AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- " LIMIT ?"
+ # We update the local_invites table only if the event is "current",
+ # i.e., it's something that has just happened. If the event is an
+ # outlier it is only current if it's an "out of band membership",
+ # like a remote invite or a rejection of a remote invite.
+ is_new_state = not backfilled and (
+ not event.internal_metadata.is_outlier()
+ or event.internal_metadata.is_out_of_band_membership()
)
- if have_forward_events:
- txn.execute(sql, (last_forward_id, current_forward_id, limit))
- new_forward_events = txn.fetchall()
-
- if len(new_forward_events) == limit:
- upper_bound = new_forward_events[-1][0]
+ is_mine = self.is_mine_id(event.state_key)
+ if is_new_state and is_mine:
+ if event.membership == Membership.INVITE:
+ self.db.simple_insert_txn(
+ txn,
+ table="local_invites",
+ values={
+ "event_id": event.event_id,
+ "invitee": event.state_key,
+ "inviter": event.sender,
+ "room_id": event.room_id,
+ "stream_id": event.internal_metadata.stream_ordering,
+ },
+ )
else:
- upper_bound = current_forward_id
-
- sql = (
- "SELECT event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (last_forward_id, upper_bound))
- forward_ex_outliers = txn.fetchall()
- else:
- new_forward_events = []
- forward_ex_outliers = []
-
- sql = (
- "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
- " FROM events AS e"
- " LEFT JOIN redactions USING (event_id)"
- " LEFT JOIN state_events USING (event_id)"
- " WHERE ? > stream_ordering AND stream_ordering >= ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
- )
- if have_backfill_events:
- txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
- new_backfill_events = txn.fetchall()
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
- if len(new_backfill_events) == limit:
- upper_bound = new_backfill_events[-1][0]
- else:
- upper_bound = current_backfill_id
-
- sql = (
- "SELECT -event_stream_ordering, event_id, state_group"
- " FROM ex_outlier_stream"
- " WHERE ? > event_stream_ordering"
- " AND event_stream_ordering >= ?"
- " ORDER BY event_stream_ordering DESC"
- )
- txn.execute(sql, (-last_backfill_id, -upper_bound))
- backward_ex_outliers = txn.fetchall()
- else:
- new_backfill_events = []
- backward_ex_outliers = []
-
- return AllNewEventsResult(
- new_forward_events,
- new_backfill_events,
- forward_ex_outliers,
- backward_ex_outliers,
- )
+ txn.execute(
+ sql,
+ (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ),
+ )
- return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
+ # We also update the `local_current_membership` table with the
+ # latest invite info. This will usually get updated by the
+ # `current_state_events` handling, unless it's an outlier.
+ if event.internal_metadata.is_outlier():
+ # This should only happen for out of band memberships, so
+ # we add a paranoia check.
+ assert event.internal_metadata.is_out_of_band_membership()
+
+ self.db.simple_upsert_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={
+ "room_id": event.room_id,
+ "user_id": event.state_key,
+ },
+ values={
+ "event_id": event.event_id,
+ "membership": event.membership,
+ },
+ )
- def purge_history(self, room_id, token, delete_local_events):
- """Deletes room history before a certain point
+ def _handle_event_relations(self, txn, event):
+ """Handles inserting relation data during peristence of events
Args:
- room_id (str):
-
- token (str): A topological token to delete events before
-
- delete_local_events (bool):
- if True, we will delete local events as well as remote ones
- (instead of just marking them as outliers and deleting their
- state groups).
-
- Returns:
- Deferred[set[int]]: The set of state groups that are referenced by
- deleted events.
+ txn
+ event (EventBase)
"""
+ relation = event.content.get("m.relates_to")
+ if not relation:
+ # No relations
+ return
- return self.db.runInteraction(
- "purge_history",
- self._purge_history_txn,
- room_id,
- token,
- delete_local_events,
- )
+ rel_type = relation.get("rel_type")
+ if rel_type not in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ RelationTypes.REPLACE,
+ ):
+ # Unknown relation type
+ return
- def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
- token = RoomStreamToken.parse(token_str)
-
- # Tables that should be pruned:
- # event_auth
- # event_backward_extremities
- # event_edges
- # event_forward_extremities
- # event_json
- # event_push_actions
- # event_reference_hashes
- # event_search
- # event_to_state_groups
- # events
- # rejections
- # room_depth
- # state_groups
- # state_groups_state
-
- # we will build a temporary table listing the events so that we don't
- # have to keep shovelling the list back and forth across the
- # connection. Annoyingly the python sqlite driver commits the
- # transaction on CREATE, so let's do this first.
- #
- # furthermore, we might already have the table from a previous (failed)
- # purge attempt, so let's drop the table first.
+ parent_id = relation.get("event_id")
+ if not parent_id:
+ # Invalid relation
+ return
- txn.execute("DROP TABLE IF EXISTS events_to_purge")
+ aggregation_key = relation.get("key")
- txn.execute(
- "CREATE TEMPORARY TABLE events_to_purge ("
- " event_id TEXT NOT NULL,"
- " should_delete BOOLEAN NOT NULL"
- ")"
+ self.db.simple_insert_txn(
+ txn,
+ table="event_relations",
+ values={
+ "event_id": event.event_id,
+ "relates_to_id": parent_id,
+ "relation_type": rel_type,
+ "aggregation_key": aggregation_key,
+ },
)
- # First ensure that we're not about to delete all the forward extremeties
- txn.execute(
- "SELECT e.event_id, e.depth FROM events as e "
- "INNER JOIN event_forward_extremities as f "
- "ON e.event_id = f.event_id "
- "AND e.room_id = f.room_id "
- "WHERE f.room_id = ?",
- (room_id,),
+ txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,))
+ txn.call_after(
+ self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
)
- rows = txn.fetchall()
- max_depth = max(row[1] for row in rows)
-
- if max_depth < token.topological:
- # We need to ensure we don't delete all the events from the database
- # otherwise we wouldn't be able to send any events (due to not
- # having any backwards extremeties)
- raise SynapseError(
- 400, "topological_ordering is greater than forward extremeties"
- )
- logger.info("[purge] looking for events to delete")
+ if rel_type == RelationTypes.REPLACE:
+ txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
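As an aside, a minimal sketch (not part of this diff; the event content is made up and the literal strings mirror the RelationTypes constants) of the relation-extraction logic above: bail out on missing, unknown, or invalid relations, otherwise return the row to insert.

    KNOWN_REL_TYPES = ("m.annotation", "m.reference", "m.replace")

    def extract_relation(content: dict):
        relation = content.get("m.relates_to")
        if not relation:
            return None  # no relation at all
        rel_type = relation.get("rel_type")
        if rel_type not in KNOWN_REL_TYPES:
            return None  # unknown relation type
        parent_id = relation.get("event_id")
        if not parent_id:
            return None  # invalid relation
        return {
            "relates_to_id": parent_id,
            "relation_type": rel_type,
            "aggregation_key": relation.get("key"),
        }

    print(
        extract_relation(
            {"m.relates_to": {"rel_type": "m.annotation", "event_id": "$parent", "key": "👍"}}
        )
    )
    # {'relates_to_id': '$parent', 'relation_type': 'm.annotation', 'aggregation_key': '👍'}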
- should_delete_expr = "state_key IS NULL"
- should_delete_params = ()
- if not delete_local_events:
- should_delete_expr += " AND event_id NOT LIKE ?"
+ def _handle_redaction(self, txn, redacted_event_id):
+ """Handles receiving a redaction and checking whether we need to remove
+ any redacted relations from the database.
- # We include the parameter twice since we use the expression twice
- should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
-
- should_delete_params += (room_id, token.topological)
-
- # Note that we insert events that are outliers and aren't going to be
- # deleted, as nothing will happen to them.
- txn.execute(
- "INSERT INTO events_to_purge"
- " SELECT event_id, %s"
- " FROM events AS e LEFT JOIN state_events USING (event_id)"
- " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
- % (should_delete_expr, should_delete_expr),
- should_delete_params,
- )
-
- # We create the indices *after* insertion as that's a lot faster.
+ Args:
+ txn
+ redacted_event_id (str): The event that was redacted.
+ """
- # create an index on should_delete because later we'll be looking for
- # the should_delete / shouldn't_delete subsets
- txn.execute(
- "CREATE INDEX events_to_purge_should_delete"
- " ON events_to_purge(should_delete)"
+ self.db.simple_delete_txn(
+ txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
- # We do joins against events_to_purge for e.g. calculating state
- # groups to purge, etc., so lets make an index.
- txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
-
- txn.execute("SELECT event_id, should_delete FROM events_to_purge")
- event_rows = txn.fetchall()
- logger.info(
- "[purge] found %i events before cutoff, of which %i can be deleted",
- len(event_rows),
- sum(1 for e in event_rows if e[1]),
- )
+ def _store_room_topic_txn(self, txn, event):
+ if hasattr(event, "content") and "topic" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.topic", event.content["topic"]
+ )
- logger.info("[purge] Finding new backward extremities")
+ def _store_room_name_txn(self, txn, event):
+ if hasattr(event, "content") and "name" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.name", event.content["name"]
+ )
- # We calculate the new entries for the backward extremeties by finding
- # events to be purged that are pointed to by events we're not going to
- # purge.
- txn.execute(
- "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
- " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
- " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
- " WHERE ep2.event_id IS NULL"
- )
- new_backwards_extrems = txn.fetchall()
+ def _store_room_message_txn(self, txn, event):
+ if hasattr(event, "content") and "body" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.body", event.content["body"]
+ )
- logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
+ def _store_retention_policy_for_room_txn(self, txn, event):
+ if hasattr(event, "content") and (
+ "min_lifetime" in event.content or "max_lifetime" in event.content
+ ):
+ if (
+ "min_lifetime" in event.content
+ and not isinstance(event.content.get("min_lifetime"), integer_types)
+ ) or (
+ "max_lifetime" in event.content
+ and not isinstance(event.content.get("max_lifetime"), integer_types)
+ ):
+ # Ignore the event if one of the values isn't an integer.
+ return
- txn.execute(
- "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
- )
+ self.db.simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ "min_lifetime": event.content.get("min_lifetime"),
+ "max_lifetime": event.content.get("max_lifetime"),
+ },
+ )
- # Update backward extremeties
- txn.executemany(
- "INSERT INTO event_backward_extremities (room_id, event_id)"
- " VALUES (?, ?)",
- [(room_id, event_id) for event_id, in new_backwards_extrems],
- )
+ self.store._invalidate_cache_and_stream(
+ txn, self.store.get_retention_policy_for_room, (event.room_id,)
+ )
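As an aside, a minimal sketch (not part of this diff; values are made up, and plain int stands in for six's integer_types) of the validation above: the retention event is ignored if either lifetime is present but not an integer.

    def parse_retention(content: dict):
        if "min_lifetime" not in content and "max_lifetime" not in content:
            return None
        for key in ("min_lifetime", "max_lifetime"):
            if key in content and not isinstance(content[key], int):
                return None  # ignore the event if one of the values isn't an integer
        return (content.get("min_lifetime"), content.get("max_lifetime"))

    print(parse_retention({"min_lifetime": 86400000}))  # (86400000, None)
    print(parse_retention({"max_lifetime": "1 week"}))  # None -- not an integer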
- logger.info("[purge] finding state groups referenced by deleted events")
+ def store_event_search_txn(self, txn, event, key, value):
+ """Add event to the search table
- # Get all state groups that are referenced by events that are to be
- # deleted.
- txn.execute(
- """
- SELECT DISTINCT state_group FROM events_to_purge
- INNER JOIN event_to_state_groups USING (event_id)
+ Args:
+ txn (cursor):
+ event (EventBase): the event whose content is being indexed
+ key (str): "content.body", "content.name" or "content.topic"
+ value (str): the text to index
"""
+ self.store.store_search_entries_txn(
+ txn,
+ (
+ SearchEntry(
+ key=key,
+ value=value,
+ event_id=event.event_id,
+ room_id=event.room_id,
+ stream_ordering=event.internal_metadata.stream_ordering,
+ origin_server_ts=event.origin_server_ts,
+ ),
+ ),
)
- referenced_state_groups = {sg for sg, in txn}
- logger.info(
- "[purge] found %i referenced state groups", len(referenced_state_groups)
- )
+ def _set_push_actions_for_event_and_users_txn(
+ self, txn, events_and_contexts, all_events_and_contexts
+ ):
+ """Handles moving push actions from staging table to main
+ event_push_actions table for all events in `events_and_contexts`.
- logger.info("[purge] removing events from event_to_state_groups")
- txn.execute(
- "DELETE FROM event_to_state_groups "
- "WHERE event_id IN (SELECT event_id from events_to_purge)"
- )
- for event_id, _ in event_rows:
- txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+ Also ensures that all events in `all_events_and_contexts` are removed
+ from the push action staging area.
- # Delete all remote non-state events
- for table in (
- "events",
- "event_json",
- "event_auth",
- "event_edges",
- "event_forward_extremities",
- "event_reference_hashes",
- "event_search",
- "rejections",
- ):
- logger.info("[purge] removing events from %s", table)
+ Args:
+ events_and_contexts (list[(EventBase, EventContext)]): events
+ we are persisting
+ all_events_and_contexts (list[(EventBase, EventContext)]): all
+ events that we were going to persist. This includes events
+ we've already persisted, etc, that wouldn't appear in
+ events_and_context.
+ """
- txn.execute(
- "DELETE FROM %s WHERE event_id IN ("
- " SELECT event_id FROM events_to_purge WHERE should_delete"
- ")" % (table,)
+ sql = """
+ INSERT INTO event_push_actions (
+ room_id, event_id, user_id, actions, stream_ordering,
+ topological_ordering, notif, highlight
)
+ SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
+ FROM event_push_actions_staging
+ WHERE event_id = ?
+ """
- # event_push_actions lacks an index on event_id, and has one on
- # (room_id, event_id) instead.
- for table in ("event_push_actions",):
- logger.info("[purge] removing events from %s", table)
+ if events_and_contexts:
+ txn.executemany(
+ sql,
+ (
+ (
+ event.room_id,
+ event.internal_metadata.stream_ordering,
+ event.depth,
+ event.event_id,
+ )
+ for event, _ in events_and_contexts
+ ),
+ )
- txn.execute(
- "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
- " SELECT event_id FROM events_to_purge WHERE should_delete"
- ")" % (table,),
- (room_id,),
+ for event, _ in events_and_contexts:
+ user_ids = self.db.simple_select_onecol_txn(
+ txn,
+ table="event_push_actions_staging",
+ keyvalues={"event_id": event.event_id},
+ retcol="user_id",
)
- # Mark all state and own events as outliers
- logger.info("[purge] marking remaining events as outliers")
- txn.execute(
- "UPDATE events SET outlier = ?"
- " WHERE event_id IN ("
- " SELECT event_id FROM events_to_purge "
- " WHERE NOT should_delete"
- ")",
- (True,),
+ for uid in user_ids:
+ txn.call_after(
+ self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
+ (event.room_id, uid),
+ )
+
+ # Now we delete the staging area for *all* events that were being
+ # persisted.
+ txn.executemany(
+ "DELETE FROM event_push_actions_staging WHERE event_id = ?",
+ ((event.event_id,) for event, _ in all_events_and_contexts),
)
- # synapse tries to take out an exclusive lock on room_depth whenever it
- # persists events (because upsert), and once we run this update, we
- # will block that for the rest of our transaction.
- #
- # So, let's stick it at the end so that we don't block event
- # persistence.
- #
- # We do this by calculating the minimum depth of the backwards
- # extremities. However, the events in event_backward_extremities
- # are ones we don't have yet so we need to look at the events that
- # point to it via event_edges table.
- txn.execute(
- """
- SELECT COALESCE(MIN(depth), 0)
- FROM event_backward_extremities AS eb
- INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
- INNER JOIN events AS e ON e.event_id = eg.event_id
- WHERE eb.room_id = ?
- """,
+ def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
+ # Sad that we have to blow away the cache for the whole room here
+ txn.call_after(
+ self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
(room_id,),
)
- (min_depth,) = txn.fetchone()
-
- logger.info("[purge] updating room_depth to %d", min_depth)
-
txn.execute(
- "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
- (min_depth, room_id),
+ "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
+ (room_id, event_id),
)
- # finally, drop the temp table. this will commit the txn in sqlite,
- # so make sure to keep this actually last.
- txn.execute("DROP TABLE events_to_purge")
-
- logger.info("[purge] done")
-
- return referenced_state_groups
-
- def purge_room(self, room_id):
- """Deletes all record of a room
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[List[int]]: The list of state groups to delete.
- """
-
- return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
-
- def _purge_room_txn(self, txn, room_id):
- # First we fetch all the state groups that should be deleted, before
- # we delete that information.
- txn.execute(
- """
- SELECT DISTINCT state_group FROM events
- INNER JOIN event_to_state_groups USING(event_id)
- WHERE events.room_id = ?
- """,
- (room_id,),
+ def _store_rejections_txn(self, txn, event_id, reason):
+ self.db.simple_insert_txn(
+ txn,
+ table="rejections",
+ values={
+ "event_id": event_id,
+ "reason": reason,
+ "last_check": self._clock.time_msec(),
+ },
)
- state_groups = [row[0] for row in txn]
-
- # Now we delete tables which lack an index on room_id but have one on event_id
- for table in (
- "event_auth",
- "event_edges",
- "event_push_actions_staging",
- "event_reference_hashes",
- "event_relations",
- "event_to_state_groups",
- "redactions",
- "rejections",
- "state_events",
- ):
- logger.info("[purge] removing %s from %s", room_id, table)
-
- txn.execute(
- """
- DELETE FROM %s WHERE event_id IN (
- SELECT event_id FROM events WHERE room_id=?
- )
- """
- % (table,),
- (room_id,),
- )
-
- # and finally, the tables with an index on room_id (or no useful index)
- for table in (
- "current_state_events",
- "event_backward_extremities",
- "event_forward_extremities",
- "event_json",
- "event_push_actions",
- "event_search",
- "events",
- "group_rooms",
- "public_room_list_stream",
- "receipts_graph",
- "receipts_linearized",
- "room_aliases",
- "room_depth",
- "room_memberships",
- "room_stats_state",
- "room_stats_current",
- "room_stats_historical",
- "room_stats_earliest_token",
- "rooms",
- "stream_ordering_to_exterm",
- "users_in_public_rooms",
- "users_who_share_private_rooms",
- # no useful index, but let's clear them anyway
- "appservice_room_list",
- "e2e_room_keys",
- "event_push_summary",
- "pusher_throttle",
- "group_summary_rooms",
- "local_invites",
- "room_account_data",
- "room_tags",
- "local_current_membership",
- ):
- logger.info("[purge] removing %s from %s", room_id, table)
- txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
-
- # Other tables we do NOT need to clear out:
- #
- # - blocked_rooms
- # This is important, to make sure that we don't accidentally rejoin a blocked
- # room after it was purged
- #
- # - user_directory
- # This has a room_id column, but it is unused
- #
-
- # Other tables that we might want to consider clearing out include:
- #
- # - event_reports
- # Given that these are intended for abuse management my initial
- # inclination is to leave them in place.
- #
- # - current_state_delta_stream
- # - ex_outlier_stream
- # - room_tags_revisions
- # The problem with these is that they are largeish and there is no room_id
- # index on them. In any case we should be clearing out 'stream' tables
- # periodically anyway (#5888)
-
- # TODO: we could probably usefully do a bunch of cache invalidation here
-
- logger.info("[purge] done")
+ def _store_event_state_mappings_txn(
+ self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
+ ):
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if event.internal_metadata.is_outlier():
+ continue
- return state_groups
+ # if the event was rejected, just give it the same state as its
+ # predecessor.
+ if context.rejected:
+ state_groups[event.event_id] = context.state_group_before_event
+ continue
- async def is_event_after(self, event_id1, event_id2):
- """Returns True if event_id1 is after event_id2 in the stream
- """
- to_1, so_1 = await self._get_event_ordering(event_id1)
- to_2, so_2 = await self._get_event_ordering(event_id2)
- return (to_1, so_1) > (to_2, so_2)
+ state_groups[event.event_id] = context.state_group
- @cachedInlineCallbacks(max_entries=5000)
- def _get_event_ordering(self, event_id):
- res = yield self.db.simple_select_one(
- table="events",
- retcols=["topological_ordering", "stream_ordering"],
- keyvalues={"event_id": event_id},
- allow_none=True,
+ self.db.simple_insert_many_txn(
+ txn,
+ table="event_to_state_groups",
+ values=[
+ {"state_group": state_group_id, "event_id": event_id}
+ for event_id, state_group_id in iteritems(state_groups)
+ ],
)
- if not res:
- raise SynapseError(404, "Could not find event %s" % (event_id,))
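+ # Once the transaction commits, pre-fill the _get_state_group_for_event
+ # cache with the mappings we have just written, so that immediate
+ # lookups for these events don't need to hit the database.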
+ for event_id, state_group_id in iteritems(state_groups):
+ txn.call_after(
+ self.store._get_state_group_for_event.prefill,
+ (event_id,),
+ state_group_id,
+ )
- return (int(res["topological_ordering"]), int(res["stream_ordering"]))
+ def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+ min_depth = self.store._get_min_depth_interaction(txn, room_id)
- def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
- def get_all_updated_current_state_deltas_txn(txn):
- sql = """
- SELECT stream_id, room_id, type, state_key, event_id
- FROM current_state_delta_stream
- WHERE ? < stream_id AND stream_id <= ?
- ORDER BY stream_id ASC LIMIT ?
- """
- txn.execute(sql, (from_token, to_token, limit))
- return txn.fetchall()
+ if min_depth is not None and depth >= min_depth:
+ return
- return self.db.runInteraction(
- "get_all_updated_current_state_deltas",
- get_all_updated_current_state_deltas_txn,
+ self.db.simple_upsert_txn(
+ txn,
+ table="room_depth",
+ keyvalues={"room_id": room_id},
+ values={"min_depth": depth},
)
- def insert_labels_for_event_txn(
- self, txn, event_id, labels, room_id, topological_ordering
- ):
- """Store the mapping between an event's ID and its labels, with one row per
- (event_id, label) tuple.
-
- Args:
- txn (LoggingTransaction): The transaction to execute.
- event_id (str): The event's ID.
- labels (list[str]): A list of text labels.
- room_id (str): The ID of the room the event was sent to.
- topological_ordering (int): The position of the event in the room's topology.
+ def _handle_mult_prev_events(self, txn, events):
"""
- return self.db.simple_insert_many_txn(
- txn=txn,
- table="event_labels",
+ For the given events, update the event_edges table and the backward
+ extremities table.
+ """
+ self.db.simple_insert_many_txn(
+ txn,
+ table="event_edges",
values=[
{
- "event_id": event_id,
- "label": label,
- "room_id": room_id,
- "topological_ordering": topological_ordering,
+ "event_id": ev.event_id,
+ "prev_event_id": e_id,
+ "room_id": ev.room_id,
+ "is_state": False,
}
- for label in labels
+ for ev in events
+ for e_id in ev.prev_event_ids()
],
)
- def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
- """Save the expiry timestamp associated with a given event ID.
+ self._update_backward_extremeties(txn, events)
- Args:
- txn (LoggingTransaction): The database transaction to use.
- event_id (str): The event ID the expiry timestamp is associated with.
- expiry_ts (int): The timestamp at which to expire (delete) the event.
- """
- return self.db.simple_insert_txn(
- txn=txn,
- table="event_expiry",
- values={"event_id": event_id, "expiry_ts": expiry_ts},
- )
+ def _update_backward_extremeties(self, txn, events):
+ """Updates the event_backward_extremities tables based on the new/updated
+ events being persisted.
- @defer.inlineCallbacks
- def expire_event(self, event_id):
- """Retrieve and expire an event that has expired, and delete its associated
- expiry timestamp. If the event can't be retrieved, delete its associated
- timestamp so we don't try to expire it again in the future.
+ This is called for new events *and* for events that were outliers, but
+ are now being persisted as non-outliers.
- Args:
- event_id (str): The ID of the event to delete.
+ Forward extremities are handled when we first start persisting the events.
"""
- # Try to retrieve the event's content from the database or the event cache.
- event = yield self.get_event(event_id)
-
- def delete_expired_event_txn(txn):
- # Delete the expiry timestamp associated with this event from the database.
- self._delete_event_expiry_txn(txn, event_id)
-
- if not event:
- # If we can't find the event, log a warning and delete the expiry date
- # from the database so that we don't try to expire it again in the
- # future.
- logger.warning(
- "Can't expire event %s because we don't have it.", event_id
- )
- return
-
- # Prune the event's dict then convert it to JSON.
- pruned_json = encode_json(
- prune_event_dict(event.room_version, event.get_dict())
- )
-
- # Update the event_json table to replace the event's JSON with the pruned
- # JSON.
- self._censor_event_txn(txn, event.event_id, pruned_json)
-
- # We need to invalidate the event cache entry for this event because we
- # changed its content in the database. We can't call
- # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
- # right type.
- txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
- # Send that invalidation to replication so that other workers also invalidate
- # the event cache.
- self._send_invalidation_to_replication(
- txn, "_get_event_cache", (event.event_id,)
- )
-
- yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
-
- def _delete_event_expiry_txn(self, txn, event_id):
- """Delete the expiry timestamp associated with an event ID without deleting the
- actual event.
+ events_by_room = {}
+ for ev in events:
+ events_by_room.setdefault(ev.room_id, []).append(ev)
+
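+ # Insert each prev_event as a backward extremity, unless it is already
+ # listed as one, or unless we already have that event persisted as a
+ # non-outlier (in which case it isn't a gap in our graph).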
+ query = (
+ "INSERT INTO event_backward_extremities (event_id, room_id)"
+ " SELECT ?, ? WHERE NOT EXISTS ("
+ " SELECT 1 FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ " )"
+ " AND NOT EXISTS ("
+ " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+ " AND outlier = ?"
+ " )"
+ )
- Args:
- txn (LoggingTransaction): The transaction to use to perform the deletion.
- event_id (str): The event ID to delete the associated expiry timestamp of.
- """
- return self.db.simple_delete_txn(
- txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+ txn.executemany(
+ query,
+ [
+ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
+ for ev in events
+ for e_id in ev.prev_event_ids()
+ if not ev.internal_metadata.is_outlier()
+ ],
)
- def get_next_event_to_expire(self):
- """Retrieve the entry with the lowest expiry timestamp in the event_expiry
- table, or None if there's no more event to expire.
+ query = (
+ "DELETE FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
+ )
+ txn.executemany(
+ query,
+ [
+ (ev.event_id, ev.room_id)
+ for ev in events
+ if not ev.internal_metadata.is_outlier()
+ ],
+ )
- Returns: Deferred[Optional[Tuple[str, int]]]
- A tuple containing the event ID as its first element and an expiry timestamp
- as its second one, if there's at least one row in the event_expiry table.
- None otherwise.
+ async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+ """Mark the invite has having been rejected even though we failed to
+ create a leave event for it.
"""
- def get_next_event_to_expire_txn(txn):
- txn.execute(
- """
- SELECT event_id, expiry_ts FROM event_expiry
- ORDER BY expiry_ts ASC LIMIT 1
- """
- )
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
- return txn.fetchone()
+ def f(txn, stream_ordering):
+ txn.execute(sql, (stream_ordering, True, room_id, user_id))
- return self.db.runInteraction(
- desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
- )
+ # We also clear this entry from `local_current_membership`.
+ # Ideally we'd point to a leave event, but we don't have one, so
+ # never mind.
+ self.db.simple_delete_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ )
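+ # Allocate the next event stream ordering for this update; exiting the
+ # `with` block marks the id as completed so the current token can
+ # advance past it.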
+ with self._stream_id_gen.get_next() as stream_ordering:
+ await self.db.runInteraction("locally_reject_invite", f, stream_ordering)
-AllNewEventsResult = namedtuple(
- "AllNewEventsResult",
- [
- "new_forward_events",
- "new_backfill_events",
- "forward_ex_outliers",
- "backward_ex_outliers",
- ],
-)
+ return stream_ordering
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index ca237c6f12..213d69100a 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,7 +19,7 @@ import itertools
import logging
import threading
from collections import namedtuple
-from typing import List, Optional
+from typing import List, Optional, Tuple
from canonicaljson import json
from constantly import NamedConstant, Names
@@ -27,7 +27,7 @@ from constantly import NamedConstant, Names
from twisted.internet import defer
from synapse.api.constants import EventTypes
-from synapse.api.errors import NotFoundError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
@@ -35,12 +35,14 @@ from synapse.api.room_versions import (
)
from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event
-from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.database import Database
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache
+from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -74,14 +76,50 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
+ if hs.config.worker.writers.events == hs.get_instance_name():
+ # We are the process in charge of generating stream ids for events,
+ # so instantiate ID generators based on the database
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ extra_tables=[("local_invites", "stream_id")],
+ )
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ )
+ else:
+ # Another process is in charge of persisting events and generating
+ # stream IDs: rely on the replication streams to let us know which
+ # IDs we can process.
+ self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
+ self._backfill_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering", step=-1
+ )
+
self._get_event_cache = Cache(
- "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
+ "*getEvent*",
+ keylen=3,
+ max_entries=hs.config.caches.event_cache_size,
+ apply_cache_factor_from_config=False,
)
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
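+ # Advance our position trackers when replication tells us how far the
+ # event streams have progressed. The backfill generator counts
+ # downwards (step=-1 above), hence the negated token.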
+ if stream_name == "events":
+ self._stream_id_gen.advance(token)
+ elif stream_name == "backfill":
+ self._backfill_id_gen.advance(-token)
+
+ super().process_replication_rows(stream_name, instance_name, token, rows)
+
def get_received_ts(self, event_id):
"""Get received_ts (when it was persisted) for the event.
@@ -409,7 +447,7 @@ class EventsWorkerStore(SQLBaseStore):
missing_events_ids = [e for e in event_ids if e not in event_entry_map]
if missing_events_ids:
- log_ctx = LoggingContext.current_context()
+ log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
# Note that _get_events_from_db is also responsible for turning db rows
@@ -632,7 +670,7 @@ class EventsWorkerStore(SQLBaseStore):
event_map[event_id] = original_ev
- # finally, we can decide whether each one nededs redacting, and build
+ # finally, we can decide whether each one needs redacting, and build
# the cache entries.
result_map = {}
for event_id, original_ev in event_map.items():
@@ -963,3 +1001,343 @@ class EventsWorkerStore(SQLBaseStore):
complexity_v1 = round(state_events / 500, 2)
return {"v1": complexity_v1}
+
+ def get_current_backfill_token(self):
+ """The current minimum token that backfilled events have reached"""
+ return -self._backfill_id_gen.get_current_token()
+
+ def get_current_events_token(self):
+ """The current maximum token that events have reached"""
+ return self._stream_id_gen.get_current_token()
+
+ def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+ """Returns new events, for the Events replication stream
+
+ Args:
+ last_id: the last stream_id from the previous batch.
+ current_id: the maximum stream_id to return up to
+ limit: the maximum number of rows to return
+
+ Returns: Deferred[List[Tuple]]
+ a list of events stream rows. Each tuple consists of a stream id as
+ the first element, followed by fields suitable for casting into an
+ EventsStreamRow.
+ """
+
+ def get_all_new_forward_event_rows(txn):
+ sql = (
+ "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+ )
+
+ def get_ex_outlier_stream_rows(self, last_id, current_id):
+ """Returns de-outliered events, for the Events replication stream
+
+ Args:
+ last_id: the last stream_id from the previous batch.
+ current_id: the maximum stream_id to return up to
+
+ Returns: Deferred[List[Tuple]]
+ a list of events stream rows. Each tuple consists of a stream id as
+ the first element, followed by fields suitable for casting into an
+ EventsStreamRow.
+ """
+
+ def get_ex_outlier_stream_rows_txn(txn):
+ sql = (
+ "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? < event_stream_ordering"
+ " AND event_stream_ordering <= ?"
+ " ORDER BY event_stream_ordering ASC"
+ )
+
+ txn.execute(sql, (last_id, current_id))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
+ )
+
+ def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_new_backfill_event_rows(txn):
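+ # Backfilled events are stored with negative stream orderings, so the
+ # tokens are negated for the WHERE comparisons and the ordering is
+ # negated again in the SELECT so that callers see positive values.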
+ sql = (
+ "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ txn.execute(sql, (-last_id, -current_id, limit))
+ new_event_updates = txn.fetchall()
+
+ if len(new_event_updates) == limit:
+ upper_bound = new_event_updates[-1][0]
+ else:
+ upper_bound = current_id
+
+ sql = (
+ "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts, relates_to_id"
+ " FROM events AS e"
+ " INNER JOIN ex_outlier_stream USING (event_id)"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
+ " WHERE ? > event_stream_ordering"
+ " AND event_stream_ordering >= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (-last_id, -upper_bound))
+ new_event_updates.extend(txn.fetchall())
+
+ return new_event_updates
+
+ return self.db.runInteraction(
+ "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+ )
+
+ async def get_all_updated_current_state_deltas(
+ self, from_token: int, to_token: int, target_row_count: int
+ ) -> Tuple[List[Tuple], int, bool]:
+ """Fetch updates from current_state_delta_stream
+
+ Args:
+ from_token: The previous stream token. Updates from this stream id will
+ be excluded.
+
+ to_token: The current stream token (ie the upper limit). Updates up to this
+ stream id will be included (modulo the 'target_row_count' param)
+
+ target_row_count: The number of rows to try to return. If more rows are
+ available, we will set 'limited' in the result. In the event of a large
+ batch, we may return more rows than this.
+ Returns:
+ A triplet `(updates, new_last_token, limited)`, where:
+ * `updates` is a list of database tuples.
+ * `new_last_token` is the new position in the stream.
+ * `limited` is whether there are more updates to fetch.
+ """
+
+ def get_all_updated_current_state_deltas_txn(txn):
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id
+ FROM current_state_delta_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, target_row_count))
+ return txn.fetchall()
+
+ def get_deltas_for_stream_id_txn(txn, stream_id):
+ sql = """
+ SELECT stream_id, room_id, type, state_key, event_id
+ FROM current_state_delta_stream
+ WHERE stream_id = ?
+ """
+ txn.execute(sql, [stream_id])
+ return txn.fetchall()
+
+ # we need to make sure that, for every stream id in the results, we get *all*
+ # the rows with that stream id.
+
+ rows = await self.db.runInteraction(
+ "get_all_updated_current_state_deltas",
+ get_all_updated_current_state_deltas_txn,
+ ) # type: List[Tuple]
+
+ # if we've got fewer rows than the limit, we're good
+ if len(rows) < target_row_count:
+ return rows, to_token, False
+
+ # we hit the limit, so reduce the upper limit so that we exclude the stream id
+ # of the last row in the result.
+ assert rows[-1][0] <= to_token
+ to_token = rows[-1][0] - 1
+
+ # search backwards through the list for the point to truncate
+ for idx in range(len(rows) - 1, 0, -1):
+ if rows[idx - 1][0] <= to_token:
+ return rows[:idx], to_token, True
+
+ # bother. We didn't get a full set of changes for even a single
+ # stream id. let's run the query again, without a row limit, but for
+ # just one stream id.
+ to_token += 1
+ rows = await self.db.runInteraction(
+ "get_deltas_for_stream_id", get_deltas_for_stream_id_txn, to_token
+ )
+
+ return rows, to_token, True
+
+ @cached(num_args=5, max_entries=10)
+ def get_all_new_events(
+ self,
+ last_backfill_id,
+ last_forward_id,
+ current_backfill_id,
+ current_forward_id,
+ limit,
+ ):
+ """Get all the new events that have arrived at the server either as
+ new events or as backfilled events"""
+ have_backfill_events = last_backfill_id != current_backfill_id
+ have_forward_events = last_forward_id != current_forward_id
+
+ if not have_backfill_events and not have_forward_events:
+ return defer.succeed(AllNewEventsResult([], [], [], [], []))
+
+ def get_all_new_events_txn(txn):
+ sql = (
+ "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? < stream_ordering AND stream_ordering <= ?"
+ " ORDER BY stream_ordering ASC"
+ " LIMIT ?"
+ )
+ if have_forward_events:
+ txn.execute(sql, (last_forward_id, current_forward_id, limit))
+ new_forward_events = txn.fetchall()
+
+ if len(new_forward_events) == limit:
+ upper_bound = new_forward_events[-1][0]
+ else:
+ upper_bound = current_forward_id
+
+ sql = (
+ "SELECT event_stream_ordering, event_id, state_group"
+ " FROM ex_outlier_stream"
+ " WHERE ? > event_stream_ordering"
+ " AND event_stream_ordering >= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (last_forward_id, upper_bound))
+ forward_ex_outliers = txn.fetchall()
+ else:
+ new_forward_events = []
+ forward_ex_outliers = []
+
+ sql = (
+ "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+ " state_key, redacts"
+ " FROM events AS e"
+ " LEFT JOIN redactions USING (event_id)"
+ " LEFT JOIN state_events USING (event_id)"
+ " WHERE ? > stream_ordering AND stream_ordering >= ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT ?"
+ )
+ if have_backfill_events:
+ txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
+ new_backfill_events = txn.fetchall()
+
+ if len(new_backfill_events) == limit:
+ upper_bound = new_backfill_events[-1][0]
+ else:
+ upper_bound = current_backfill_id
+
+ sql = (
+ "SELECT -event_stream_ordering, event_id, state_group"
+ " FROM ex_outlier_stream"
+ " WHERE ? > event_stream_ordering"
+ " AND event_stream_ordering >= ?"
+ " ORDER BY event_stream_ordering DESC"
+ )
+ txn.execute(sql, (-last_backfill_id, -upper_bound))
+ backward_ex_outliers = txn.fetchall()
+ else:
+ new_backfill_events = []
+ backward_ex_outliers = []
+
+ return AllNewEventsResult(
+ new_forward_events,
+ new_backfill_events,
+ forward_ex_outliers,
+ backward_ex_outliers,
+ )
+
+ return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
+
+ async def is_event_after(self, event_id1, event_id2):
+ """Returns True if event_id1 is after event_id2 in the stream
+ """
+ to_1, so_1 = await self.get_event_ordering(event_id1)
+ to_2, so_2 = await self.get_event_ordering(event_id2)
+ return (to_1, so_1) > (to_2, so_2)
+
+ @cachedInlineCallbacks(max_entries=5000)
+ def get_event_ordering(self, event_id):
+ res = yield self.db.simple_select_one(
+ table="events",
+ retcols=["topological_ordering", "stream_ordering"],
+ keyvalues={"event_id": event_id},
+ allow_none=True,
+ )
+
+ if not res:
+ raise SynapseError(404, "Could not find event %s" % (event_id,))
+
+ return (int(res["topological_ordering"]), int(res["stream_ordering"]))
+
+ def get_next_event_to_expire(self):
+ """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+ table, or None if there's no more event to expire.
+
+ Returns: Deferred[Optional[Tuple[str, int]]]
+ A tuple containing the event ID as its first element and an expiry timestamp
+ as its second one, if there's at least one row in the event_expiry table.
+ None otherwise.
+ """
+
+ def get_next_event_to_expire_txn(txn):
+ txn.execute(
+ """
+ SELECT event_id, expiry_ts FROM event_expiry
+ ORDER BY expiry_ts ASC LIMIT 1
+ """
+ )
+
+ return txn.fetchone()
+
+ return self.db.runInteraction(
+ desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+ )
+
+
+AllNewEventsResult = namedtuple(
+ "AllNewEventsResult",
+ [
+ "new_forward_events",
+ "new_backfill_events",
+ "forward_ex_outliers",
+ "backward_ex_outliers",
+ ],
+)
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 0963e6c250..fb1361f1c1 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -68,24 +68,78 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_invited_users_in_group",
)
- def get_rooms_in_group(self, group_id, include_private=False):
+ def get_rooms_in_group(self, group_id: str, include_private: bool = False):
+ """Retrieve the rooms that belong to a given group. Does not return rooms that
+ lack members.
+
+ Args:
+ group_id: The ID of the group to query for rooms
+ include_private: Whether to return private rooms in results
+
+ Returns:
+ Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
+ form of:
+
+ {
+ "room_id": "!a_room_id:example.com", # The ID of the room
+ "is_public": False # Whether this is a public room or not
+ }
+ """
# TODO: Pagination
- keyvalues = {"group_id": group_id}
- if not include_private:
- keyvalues["is_public"] = True
+ def _get_rooms_in_group_txn(txn):
+ sql = """
+ SELECT room_id, is_public FROM group_rooms
+ WHERE group_id = ?
+ AND room_id IN (
+ SELECT group_rooms.room_id FROM group_rooms
+ LEFT JOIN room_stats_current ON
+ group_rooms.room_id = room_stats_current.room_id
+ AND joined_members > 0
+ AND local_users_in_room > 0
+ LEFT JOIN rooms ON
+ group_rooms.room_id = rooms.room_id
+ AND (room_version <> '') = ?
+ )
+ """
+ args = [group_id, False]
- return self.db.simple_select_list(
- table="group_rooms",
- keyvalues=keyvalues,
- retcols=("room_id", "is_public"),
- desc="get_rooms_in_group",
- )
+ if not include_private:
+ sql += " AND is_public = ?"
+ args += [True]
+
+ txn.execute(sql, args)
+
+ return [
+ {"room_id": room_id, "is_public": is_public}
+ for room_id, is_public in txn
+ ]
- def get_rooms_for_summary_by_category(self, group_id, include_private=False):
+ return self.db.runInteraction("get_rooms_in_group", _get_rooms_in_group_txn)
+
+ def get_rooms_for_summary_by_category(
+ self, group_id: str, include_private: bool = False,
+ ):
"""Get the rooms and categories that should be included in a summary request
- Returns ([rooms], [categories])
+ Args:
+ group_id: The ID of the group to query the summary for
+ include_private: Whether to return private rooms in results
+
+ Returns:
+ Deferred[Tuple[List, Dict]]: A tuple containing:
+
+ * A list of dictionaries with the keys:
+ * "room_id": str, the room ID
+ * "is_public": bool, whether the room is public
+ * "category_id": str|None, the category ID if set, else None
+ * "order": int, the sort order of rooms
+
+ * A dictionary with the key:
+ * category_id (str): a dictionary with the keys:
+ * "is_public": bool, whether the category is public
+ * "profile": str, the category profile
+ * "order": int, the sort order of rooms in this category
"""
def _get_rooms_for_summary_txn(txn):
@@ -97,13 +151,23 @@ class GroupServerWorkerStore(SQLBaseStore):
SELECT room_id, is_public, category_id, room_order
FROM group_summary_rooms
WHERE group_id = ?
+ AND room_id IN (
+ SELECT group_rooms.room_id FROM group_rooms
+ LEFT JOIN room_stats_current ON
+ group_rooms.room_id = room_stats_current.room_id
+ AND joined_members > 0
+ AND local_users_in_room > 0
+ LEFT JOIN rooms ON
+ group_rooms.room_id = rooms.room_id
+ AND (room_version <> '') = ?
+ )
"""
if not include_private:
sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
+ txn.execute(sql, (group_id, False, True))
else:
- txn.execute(sql, (group_id,))
+ txn.execute(sql, (group_id, False))
rooms = [
{
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
index ba89c68c9f..4e1642a27a 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/data_stores/main/keys.py
@@ -17,8 +17,6 @@
import itertools
import logging
-import six
-
from signedjson.key import decode_verify_key_bytes
from synapse.storage._base import SQLBaseStore
@@ -28,12 +26,8 @@ from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
+
+db_binary_type = memoryview
class KeyStore(SQLBaseStore):
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 80ca36dedf..8aecd414c2 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -340,7 +340,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_expired_url_cache", _get_expired_url_cache_txn
)
- def delete_url_cache(self, media_ids):
+ async def delete_url_cache(self, media_ids):
if len(media_ids) == 0:
return
@@ -349,7 +349,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
+ return await self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
def get_url_cache_media_before(self, before_ts):
sql = (
@@ -367,7 +367,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
- def delete_url_cache_media(self, media_ids):
+ async def delete_url_cache_media(self, media_ids):
if len(media_ids) == 0:
return
@@ -380,6 +380,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)
diff --git a/synapse/storage/data_stores/main/metrics.py b/synapse/storage/data_stores/main/metrics.py
new file mode 100644
index 0000000000..dad5bbc602
--- /dev/null
+++ b/synapse/storage/data_stores/main/metrics.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import typing
+from collections import Counter
+
+from twisted.internet import defer
+
+from synapse.metrics import BucketCollector
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.event_push_actions import (
+ EventPushActionsWorkerStore,
+)
+from synapse.storage.database import Database
+
+
+class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
+ """Functions to pull various metrics from the DB, for e.g. phone home
+ stats and prometheus metrics.
+ """
+
+ def __init__(self, database: Database, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ # Collect metrics on the number of forward extremities that exist.
+ # Counter of number of extremities to count
+ self._current_forward_extremities_amount = (
+ Counter()
+ ) # type: typing.Counter[int]
+
+ BucketCollector(
+ "synapse_forward_extremities",
+ lambda: self._current_forward_extremities_amount,
+ buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
+ )
+
+ # Read the extrems every 60 minutes
+ def read_forward_extremities():
+ # run as a background process to make sure that the database transactions
+ # have a logcontext to report to
+ return run_as_background_process(
+ "read_forward_extremities", self._read_forward_extremities
+ )
+
+ hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
+
+ async def _read_forward_extremities(self):
+ def fetch(txn):
+ txn.execute(
+ """
+ select count(*) c from event_forward_extremities
+ group by room_id
+ """
+ )
+ return txn.fetchall()
+
+ res = await self.db.runInteraction("read_forward_extremities", fetch)
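+ # `res` has one row per room, containing that room's forward-extremity
+ # count; the Counter therefore maps an extremity count to the number of
+ # rooms with that count, which is what the bucket collector reports.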
+ self._current_forward_extremities_amount = Counter([x[0] for x in res])
+
+ @defer.inlineCallbacks
+ def count_daily_messages(self):
+ """
+ Returns an estimate of the number of messages sent in the last day.
+
+ If it has been significantly less or more than one day since the last
+ call to this function, it will return None.
+ """
+
+ def _count_messages(txn):
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.message'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ ret = yield self.db.runInteraction("count_messages", _count_messages)
+ return ret
+
+ @defer.inlineCallbacks
+ def count_daily_sent_messages(self):
+ def _count_messages(txn):
+ # This is good enough as if you have silly characters in your own
+ # hostname then that's your own fault.
+ like_clause = "%:" + self.hs.hostname
+
+ sql = """
+ SELECT COALESCE(COUNT(*), 0) FROM events
+ WHERE type = 'm.room.message'
+ AND sender LIKE ?
+ AND stream_ordering > ?
+ """
+
+ txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
+ (count,) = txn.fetchone()
+ return count
+
+ ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
+ return ret
+
+ @defer.inlineCallbacks
+ def count_daily_active_rooms(self):
+ def _count(txn):
+ sql = """
+ SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events
+ WHERE type = 'm.room.message'
+ AND stream_ordering > ?
+ """
+ txn.execute(sql, (self.stream_ordering_day_ago,))
+ (count,) = txn.fetchone()
+ return count
+
+ ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
+ return ret
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 925bc5691b..e459cf49a0 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -13,11 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import Database
+from synapse.storage.database import Database, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -77,20 +78,19 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return self.db.runInteraction("count_users_by_service", _count_users_by_service)
- @defer.inlineCallbacks
- def get_registered_reserved_users(self):
- """Of the reserved threepids defined in config, which are associated
- with registered users?
+ async def get_registered_reserved_users(self) -> List[str]:
+ """Of the reserved threepids defined in config, retrieve those that are associated
+ with registered users
Returns:
- Defered[list]: Real reserved users
+ User IDs of actual users that are reserved
"""
users = []
for tp in self.hs.config.mau_limits_reserved_threepids[
: self.hs.config.max_mau_value
]:
- user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ user_id = await self.hs.get_datastore().get_user_id_by_threepid(
tp["medium"], tp["address"]
)
if user_id:
@@ -122,6 +122,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: Database, db_conn, hs):
super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+ self._limit_usage_by_mau = hs.config.limit_usage_by_mau
+ self._mau_stats_only = hs.config.mau_stats_only
+ self._max_mau_value = hs.config.max_mau_value
+
# Do not add more reserved users than the total allowable number
# cur = LoggingTransaction(
self.db.new_transaction(
@@ -130,7 +134,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
[],
[],
self._initialise_reserved_users,
- hs.config.mau_limits_reserved_threepids[: self.hs.config.max_mau_value],
+ hs.config.mau_limits_reserved_threepids[: self._max_mau_value],
)
def _initialise_reserved_users(self, txn, threepids):
@@ -142,6 +146,15 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
threepids (list[dict]): List of threepid dicts to reserve
"""
+ # XXX what is this function trying to achieve? It upserts into
+ # monthly_active_users for each *registered* reserved mau user, but why?
+ #
+ # - shouldn't there already be an entry for each reserved user (at least
+ # if they have been active recently)?
+ #
+ # - if it's important that the timestamp is kept up to date, why do we only
+ # run this at startup?
+
for tp in threepids:
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
@@ -158,13 +171,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
else:
logger.warning("mau limit reserved threepid %s not found in db" % tp)
- @defer.inlineCallbacks
- def reap_monthly_active_users(self):
+ async def reap_monthly_active_users(self):
"""Cleans out monthly active user table to ensure that no stale
entries exist.
-
- Returns:
- Deferred[]
"""
def _reap_users(txn, reserved_users):
@@ -174,76 +183,57 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
- query_args = [thirty_days_ago]
- base_sql = "DELETE FROM monthly_active_users WHERE timestamp < ?"
-
- # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
- # when len(reserved_users) == 0. Works fine on sqlite.
- if len(reserved_users) > 0:
- # questionmarks is a hack to overcome sqlite not supporting
- # tuples in 'WHERE IN %s'
- question_marks = ",".join("?" * len(reserved_users))
-
- query_args.extend(reserved_users)
- sql = base_sql + " AND user_id NOT IN ({})".format(question_marks)
- else:
- sql = base_sql
- txn.execute(sql, query_args)
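+ # make_in_list_sql_clause builds a single membership test for user_id
+ # against reserved_users that works on both supported database engines,
+ # replacing the per-engine question-mark handling removed above.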
+ in_clause, in_clause_args = make_in_list_sql_clause(
+ self.database_engine, "user_id", reserved_users
+ )
+
+ txn.execute(
+ "DELETE FROM monthly_active_users WHERE timestamp < ? AND NOT %s"
+ % (in_clause,),
+ [thirty_days_ago] + in_clause_args,
+ )
- max_mau_value = self.hs.config.max_mau_value
- if self.hs.config.limit_usage_by_mau:
+ if self._limit_usage_by_mau:
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
# Note it is not possible to write this query using OFFSET due to
# incompatibilities in how sqlite and postgres support the feature.
- # sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present
- # While Postgres does not require 'LIMIT', but also does not support
+ # Sqlite requires 'LIMIT -1 OFFSET ?', the LIMIT must be present,
+ # while Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can
# support
- if len(reserved_users) == 0:
- sql = """
- DELETE FROM monthly_active_users
- WHERE user_id NOT IN (
- SELECT user_id FROM monthly_active_users
- ORDER BY timestamp DESC
- LIMIT ?
- )
- """
- txn.execute(sql, (max_mau_value,))
- # Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
- # when len(reserved_users) == 0. Works fine on sqlite.
- else:
- # Must be >= 0 for postgres
- num_of_non_reserved_users_to_remove = max(
- max_mau_value - len(reserved_users), 0
- )
- # It is important to filter reserved users twice to guard
- # against the case where the reserved user is present in the
- # SELECT, meaning that a legitmate mau is deleted.
- sql = """
- DELETE FROM monthly_active_users
- WHERE user_id NOT IN (
- SELECT user_id FROM monthly_active_users
- WHERE user_id NOT IN ({})
- ORDER BY timestamp DESC
- LIMIT ?
- )
- AND user_id NOT IN ({})
- """.format(
- question_marks, question_marks
+ # Limit must be >= 0 for postgres
+ num_of_non_reserved_users_to_remove = max(
+ self._max_mau_value - len(reserved_users), 0
+ )
+
+ # It is important to filter reserved users twice to guard
+ # against the case where the reserved user is present in the
+ # SELECT, meaning that a legitimate mau is deleted.
+ sql = """
+ DELETE FROM monthly_active_users
+ WHERE user_id NOT IN (
+ SELECT user_id FROM monthly_active_users
+ WHERE NOT %s
+ ORDER BY timestamp DESC
+ LIMIT ?
)
-
- query_args = [
- *reserved_users,
- num_of_non_reserved_users_to_remove,
- *reserved_users,
- ]
-
- txn.execute(sql, query_args)
-
- # It seems poor to invalidate the whole cache, Postgres supports
+ AND NOT %s
+ """ % (
+ in_clause,
+ in_clause,
+ )
+
+ query_args = (
+ in_clause_args
+ + [num_of_non_reserved_users_to_remove]
+ + in_clause_args
+ )
+ txn.execute(sql, query_args)
+
+ # It seems poor to invalidate the whole cache. Postgres supports
# 'Returning' which would allow me to invalidate only the
# specific users, but sqlite has no way to do this and instead
# I would need to SELECT and the DELETE which without locking
@@ -255,8 +245,8 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
)
self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
- reserved_users = yield self.get_registered_reserved_users()
- yield self.db.runInteraction(
+ reserved_users = await self.get_registered_reserved_users()
+ await self.db.runInteraction(
"reap_monthly_active_users", _reap_users, reserved_users
)
@@ -267,6 +257,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
Args:
user_id (str): user to add/update
+
+ Returns:
+ Deferred
"""
# Support user never to be included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
@@ -335,7 +328,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
Args:
user_id(str): the user_id to query
"""
- if self.hs.config.limit_usage_by_mau or self.hs.config.mau_stats_only:
+ if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
is_guest = yield self.is_guest(user_id)
if is_guest:
@@ -356,11 +349,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
# In the case where mau_stats_only is True and limit_usage_by_mau is
# False, there is no point in checking get_monthly_active_count - it
# adds no value and will break the logic if max_mau_value is exceeded.
- if not self.hs.config.limit_usage_by_mau:
+ if not self._limit_usage_by_mau:
yield self.upsert_monthly_active_user(user_id)
else:
count = yield self.get_monthly_active_count()
- if count < self.hs.config.max_mau_value:
+ if count < self._max_mau_value:
yield self.upsert_monthly_active_user(user_id)
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
yield self.upsert_monthly_active_user(user_id)
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index 604c8b7ddd..dab31e0c2d 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -60,7 +60,7 @@ class PresenceStore(SQLBaseStore):
"status_msg": state.status_msg,
"currently_active": state.currently_active,
}
- for state in presence_states
+ for stream_id, state in zip(stream_orderings, presence_states)
],
)
@@ -73,19 +73,22 @@ class PresenceStore(SQLBaseStore):
)
txn.execute(sql + clause, [stream_id] + list(args))
- def get_all_presence_updates(self, last_id, current_id):
+ def get_all_presence_updates(self, last_id, current_id, limit):
if last_id == current_id:
return defer.succeed([])
def get_all_presence_updates_txn(txn):
- sql = (
- "SELECT stream_id, user_id, state, last_active_ts,"
- " last_federation_update_ts, last_user_sync_ts, status_msg,"
- " currently_active"
- " FROM presence_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- )
- txn.execute(sql, (last_id, current_id))
+ sql = """
+ SELECT stream_id, user_id, state, last_active_ts,
+ last_federation_update_ts, last_user_sync_ts,
+ status_msg,
+ currently_active
+ FROM presence_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
return self.db.runInteraction(
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index 2b52cf9c1a..bfc9369f0b 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -110,7 +110,7 @@ class ProfileStore(ProfileWorkerStore):
return self.db.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
- values={
+ updatevalues={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
new file mode 100644
index 0000000000..a93e1ef198
--- /dev/null
+++ b/synapse/storage/data_stores/main/purge_events.py
@@ -0,0 +1,399 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, Tuple
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.types import RoomStreamToken
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
+ def purge_history(self, room_id, token, delete_local_events):
+ """Deletes room history before a certain point
+
+ Args:
+ room_id (str):
+
+ token (str): A topological token to delete events before
+
+ delete_local_events (bool):
+ if True, we will delete local events as well as remote ones
+ (instead of just marking them as outliers and deleting their
+ state groups).
+
+ Returns:
+ Deferred[set[int]]: The set of state groups that are referenced by
+ deleted events.
+ """
+
+ return self.db.runInteraction(
+ "purge_history",
+ self._purge_history_txn,
+ room_id,
+ token,
+ delete_local_events,
+ )
+
+ def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
+ token = RoomStreamToken.parse(token_str)
+
+ # Tables that should be pruned:
+ # event_auth
+ # event_backward_extremities
+ # event_edges
+ # event_forward_extremities
+ # event_json
+ # event_push_actions
+ # event_reference_hashes
+ # event_search
+ # event_to_state_groups
+ # events
+ # rejections
+ # room_depth
+ # state_groups
+ # state_groups_state
+
+ # we will build a temporary table listing the events so that we don't
+ # have to keep shovelling the list back and forth across the
+ # connection. Annoyingly the python sqlite driver commits the
+ # transaction on CREATE, so let's do this first.
+ #
+ # furthermore, we might already have the table from a previous (failed)
+ # purge attempt, so let's drop the table first.
+
+ txn.execute("DROP TABLE IF EXISTS events_to_purge")
+
+ txn.execute(
+ "CREATE TEMPORARY TABLE events_to_purge ("
+ " event_id TEXT NOT NULL,"
+ " should_delete BOOLEAN NOT NULL"
+ ")"
+ )
+
+ # First ensure that we're not about to delete all the forward extremities
+ txn.execute(
+ "SELECT e.event_id, e.depth FROM events as e "
+ "INNER JOIN event_forward_extremities as f "
+ "ON e.event_id = f.event_id "
+ "AND e.room_id = f.room_id "
+ "WHERE f.room_id = ?",
+ (room_id,),
+ )
+ rows = txn.fetchall()
+ max_depth = max(row[1] for row in rows)
+
+ if max_depth < token.topological:
+ # We need to ensure we don't delete all the events from the database
+ # otherwise we wouldn't be able to send any events (due to not
+ # having any backward extremities)
+ raise SynapseError(
+ 400, "topological_ordering is greater than forward extremeties"
+ )
+
+ logger.info("[purge] looking for events to delete")
+
+ should_delete_expr = "state_key IS NULL"
+ should_delete_params = () # type: Tuple[Any, ...]
+ if not delete_local_events:
+ should_delete_expr += " AND event_id NOT LIKE ?"
+
+ # We include the parameter twice since we use the expression twice
+ should_delete_params += ("%:" + self.hs.hostname, "%:" + self.hs.hostname)
+
+ should_delete_params += (room_id, token.topological)
+
+ # Note that we insert events that are outliers and aren't going to be
+ # deleted, as nothing will happen to them.
+ txn.execute(
+ "INSERT INTO events_to_purge"
+ " SELECT event_id, %s"
+ " FROM events AS e LEFT JOIN state_events USING (event_id)"
+ " WHERE (NOT outlier OR (%s)) AND e.room_id = ? AND topological_ordering < ?"
+ % (should_delete_expr, should_delete_expr),
+ should_delete_params,
+ )
+
+ # We create the indices *after* insertion as that's a lot faster.
+
+ # create an index on should_delete because later we'll be looking for
+ # the should_delete / shouldn't_delete subsets
+ txn.execute(
+ "CREATE INDEX events_to_purge_should_delete"
+ " ON events_to_purge(should_delete)"
+ )
+
+ # We do joins against events_to_purge for e.g. calculating state
+ # groups to purge, etc., so let's make an index.
+ txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
+
+ txn.execute("SELECT event_id, should_delete FROM events_to_purge")
+ event_rows = txn.fetchall()
+ logger.info(
+ "[purge] found %i events before cutoff, of which %i can be deleted",
+ len(event_rows),
+ sum(1 for e in event_rows if e[1]),
+ )
+
+ logger.info("[purge] Finding new backward extremities")
+
+ # We calculate the new entries for the backward extremities by finding
+ # events to be purged that are pointed to by events we're not going to
+ # purge.
+ txn.execute(
+ "SELECT DISTINCT e.event_id FROM events_to_purge AS e"
+ " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
+ " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
+ " WHERE ep2.event_id IS NULL"
+ )
+ new_backwards_extrems = txn.fetchall()
+
+ logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
+
+ txn.execute(
+ "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,)
+ )
+
+ # Update backward extremities
+ txn.executemany(
+ "INSERT INTO event_backward_extremities (room_id, event_id)"
+ " VALUES (?, ?)",
+ [(room_id, event_id) for event_id, in new_backwards_extrems],
+ )
+
+ logger.info("[purge] finding state groups referenced by deleted events")
+
+ # Get all state groups that are referenced by events that are to be
+ # deleted.
+ txn.execute(
+ """
+ SELECT DISTINCT state_group FROM events_to_purge
+ INNER JOIN event_to_state_groups USING (event_id)
+ """
+ )
+
+ referenced_state_groups = {sg for sg, in txn}
+ logger.info(
+ "[purge] found %i referenced state groups", len(referenced_state_groups)
+ )
+
+ logger.info("[purge] removing events from event_to_state_groups")
+ txn.execute(
+ "DELETE FROM event_to_state_groups "
+ "WHERE event_id IN (SELECT event_id from events_to_purge)"
+ )
+ for event_id, _ in event_rows:
+ txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
+
+ # Delete all remote non-state events
+ for table in (
+ "events",
+ "event_json",
+ "event_auth",
+ "event_edges",
+ "event_forward_extremities",
+ "event_reference_hashes",
+ "event_search",
+ "rejections",
+ ):
+ logger.info("[purge] removing events from %s", table)
+
+ txn.execute(
+ "DELETE FROM %s WHERE event_id IN ("
+ " SELECT event_id FROM events_to_purge WHERE should_delete"
+ ")" % (table,)
+ )
+
+ # event_push_actions lacks an index on event_id, and has one on
+ # (room_id, event_id) instead.
+ for table in ("event_push_actions",):
+ logger.info("[purge] removing events from %s", table)
+
+ txn.execute(
+ "DELETE FROM %s WHERE room_id = ? AND event_id IN ("
+ " SELECT event_id FROM events_to_purge WHERE should_delete"
+ ")" % (table,),
+ (room_id,),
+ )
+
+ # Mark all state and own events as outliers
+ logger.info("[purge] marking remaining events as outliers")
+ txn.execute(
+ "UPDATE events SET outlier = ?"
+ " WHERE event_id IN ("
+ " SELECT event_id FROM events_to_purge "
+ " WHERE NOT should_delete"
+ ")",
+ (True,),
+ )
+
+ # synapse tries to take out an exclusive lock on room_depth whenever it
+ # persists events (because upsert), and once we run this update, we
+ # will block that for the rest of our transaction.
+ #
+ # So, let's stick it at the end so that we don't block event
+ # persistence.
+ #
+ # We do this by calculating the minimum depth of the backwards
+ # extremities. However, the events in event_backward_extremities
+ # are ones we don't have yet so we need to look at the events that
+ # point to it via event_edges table.
+ txn.execute(
+ """
+ SELECT COALESCE(MIN(depth), 0)
+ FROM event_backward_extremities AS eb
+ INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
+ INNER JOIN events AS e ON e.event_id = eg.event_id
+ WHERE eb.room_id = ?
+ """,
+ (room_id,),
+ )
+ (min_depth,) = txn.fetchone()
+
+ logger.info("[purge] updating room_depth to %d", min_depth)
+
+ txn.execute(
+ "UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
+ (min_depth, room_id),
+ )
+
+ # finally, drop the temp table. this will commit the txn in sqlite,
+ # so make sure to keep this actually last.
+ txn.execute("DROP TABLE events_to_purge")
+
+ logger.info("[purge] done")
+
+ return referenced_state_groups
+
+ def purge_room(self, room_id):
+ """Deletes all record of a room
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[List[int]]: The list of state groups to delete.
+ """
+
+ return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
+
+ def _purge_room_txn(self, txn, room_id):
+ # First we fetch all the state groups that should be deleted, before
+ # we delete that information.
+ txn.execute(
+ """
+ SELECT DISTINCT state_group FROM events
+ INNER JOIN event_to_state_groups USING(event_id)
+ WHERE events.room_id = ?
+ """,
+ (room_id,),
+ )
+
+ state_groups = [row[0] for row in txn]
+
+ # Now we delete tables which lack an index on room_id but have one on event_id
+ for table in (
+ "event_auth",
+ "event_edges",
+ "event_push_actions_staging",
+ "event_reference_hashes",
+ "event_relations",
+ "event_to_state_groups",
+ "redactions",
+ "rejections",
+ "state_events",
+ ):
+ logger.info("[purge] removing %s from %s", room_id, table)
+
+ txn.execute(
+ """
+ DELETE FROM %s WHERE event_id IN (
+ SELECT event_id FROM events WHERE room_id=?
+ )
+ """
+ % (table,),
+ (room_id,),
+ )
+
+ # and finally, the tables with an index on room_id (or no useful index)
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ # no useful index, but let's clear them anyway
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ "local_current_membership",
+ ):
+ logger.info("[purge] removing %s from %s", room_id, table)
+ txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
+
+ # Other tables we do NOT need to clear out:
+ #
+ # - blocked_rooms
+ # This is important, to make sure that we don't accidentally rejoin a blocked
+ # room after it was purged
+ #
+ # - user_directory
+ # This has a room_id column, but it is unused
+ #
+
+ # Other tables that we might want to consider clearing out include:
+ #
+ # - event_reports
+ # Given that these are intended for abuse management, my initial
+ # inclination is to leave them in place.
+ #
+ # - current_state_delta_stream
+ # - ex_outlier_stream
+ # - room_tags_revisions
+ # The problem with these is that they are largeish and there is no room_id
+ # index on them. In any case we should be clearing out 'stream' tables
+ # periodically anyway (#5888)
+
+ # TODO: we could probably usefully do a bunch of cache invalidation here
+
+ logger.info("[purge] done")
+
+ return state_groups
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index 62ac88d9f2..ef8f40959f 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -16,19 +16,23 @@
import abc
import logging
+from typing import Union
from canonicaljson import json
from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
from synapse.storage.database import Database
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.storage.util.id_generators import ChainedIdGenerator
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -41,6 +45,7 @@ def _load_rules(rawrules, enabled_map):
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
+ rule["default"] = False
ruleslist.append(rule)
# We're going to be mutating this a lot, so do a deep copy
@@ -63,6 +68,7 @@ class PushRulesWorkerStore(
ReceiptsWorkerStore,
PusherWorkerStore,
RoomMemberWorkerStore,
+ EventsWorkerStore,
SQLBaseStore,
):
"""This is an abstract base class where subclasses must implement
@@ -76,6 +82,15 @@ class PushRulesWorkerStore(
def __init__(self, database: Database, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
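+ # Only the main process writes to the push rules stream; workers just track
+ # the advancing stream position.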
+ if hs.config.worker.worker_app is None:
+ self._push_rules_stream_id_gen = ChainedIdGenerator(
+ self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+ ) # type: Union[ChainedIdGenerator, SlavedIdTracker]
+ else:
+ self._push_rules_stream_id_gen = SlavedIdTracker(
+ db_conn, "push_rules_stream", "stream_id"
+ )
+
push_rules_prefill, push_rules_id = self.db.get_cache_dict(
db_conn,
"push_rules_stream",
@@ -333,6 +348,26 @@ class PushRulesWorkerStore(
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
return results
+ def get_all_push_rule_updates(self, last_id, current_id, limit):
+ """Get all the push rules changes that have happend on the server"""
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_push_rule_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
+ " op, priority_class, priority, conditions, actions"
+ " FROM push_rules_stream"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_push_rule_updates", get_all_push_rule_updates_txn
+ )
+
class PushRuleStore(PushRulesWorkerStore):
@defer.inlineCallbacks
@@ -684,26 +719,6 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
- def get_all_push_rule_updates(self, last_id, current_id, limit):
- """Get all the push rules changes that have happend on the server"""
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_push_rule_updates_txn(txn):
- sql = (
- "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
- " op, priority_class, priority, conditions, actions"
- " FROM push_rules_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
-
- return self.db.runInteraction(
- "get_all_push_rule_updates", get_all_push_rule_updates_txn
- )
-
def get_push_rules_stream_token(self):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 0d932a0672..cebdcd409f 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -391,7 +391,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type),
)
- self.db.simple_delete_txn(
+ self.db.simple_upsert_txn(
txn,
table="receipts_linearized",
keyvalues={
@@ -399,19 +399,14 @@ class ReceiptsStore(ReceiptsWorkerStore):
"receipt_type": receipt_type,
"user_id": user_id,
},
- )
-
- self.db.simple_insert_txn(
- txn,
- table="receipts_linearized",
values={
"stream_id": stream_id,
- "room_id": room_id,
- "receipt_type": receipt_type,
- "user_id": user_id,
"event_id": event_id,
"data": json.dumps(data),
},
+ # receipts_linearized has a unique constraint on
+ # (user_id, room_id, receipt_type), so no need to lock
+ lock=False,
)
if receipt_type == "m.read" and stream_ordering is not None:
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index 3e53c8568a..9768981891 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -17,6 +17,7 @@
import logging
import re
+from typing import Optional
from six import iterkeys
@@ -273,8 +274,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user",
)
- @defer.inlineCallbacks
- def is_server_admin(self, user):
+ async def is_server_admin(self, user):
"""Determines if a user is an admin of this homeserver.
Args:
@@ -283,7 +283,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns (bool):
true iff the user is a server admin, false otherwise.
"""
- res = yield self.db.simple_select_one_onecol(
+ res = await self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
@@ -343,7 +343,7 @@ class RegistrationWorkerStore(SQLBaseStore):
)
return res
- @cachedInlineCallbacks()
+ @cached()
def is_support_user(self, user_id):
"""Determines if the user is of type UserTypes.SUPPORT
@@ -353,10 +353,9 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if user is of type UserTypes.SUPPORT
"""
- res = yield self.db.runInteraction(
+ return self.db.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
- return res
def is_real_user_txn(self, txn, user_id):
res = self.db.simple_select_one_onecol_txn(
@@ -517,18 +516,17 @@ class RegistrationWorkerStore(SQLBaseStore):
)
)
- @defer.inlineCallbacks
- def get_user_id_by_threepid(self, medium, address):
+ async def get_user_id_by_threepid(self, medium: str, address: str) -> Optional[str]:
"""Returns user id from threepid
Args:
- medium (str): threepid medium e.g. email
- address (str): threepid address e.g. me@example.com
+ medium: threepid medium e.g. email
+ address: threepid address e.g. me@example.com
Returns:
- Deferred[str|None]: user id or None if no user id/threepid mapping exists
+ The user ID or None if no user id/threepid mapping exists
"""
- user_id = yield self.db.runInteraction(
+ user_id = await self.db.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
return user_id
@@ -994,7 +992,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Args:
user_id (str): The desired user ID to register.
- password_hash (str): Optional. The password hash for this user.
+ password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
make_guest (boolean): True if the the new user should be guest,
@@ -1008,6 +1006,9 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Raises:
StoreError if the user_id could not be registered.
+
+ Returns:
+ Deferred
"""
return self.db.runInteraction(
"register_user",
diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py
index 1c07c7a425..27e5a2084a 100644
--- a/synapse/storage/data_stores/main/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -21,17 +21,6 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
- def _store_rejections_txn(self, txn, event_id, reason):
- self.db.simple_insert_txn(
- txn,
- table="rejections",
- values={
- "event_id": event_id,
- "reason": reason,
- "last_check": self._clock.time_msec(),
- },
- )
-
def get_rejection_reason(self, event_id):
return self.db.simple_select_one_onecol(
table="rejections",
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
index 046c2b4845..7d477f8d01 100644
--- a/synapse/storage/data_stores/main/relations.py
+++ b/synapse/storage/data_stores/main/relations.py
@@ -324,62 +324,4 @@ class RelationsWorkerStore(SQLBaseStore):
class RelationsStore(RelationsWorkerStore):
- def _handle_event_relations(self, txn, event):
- """Handles inserting relation data during peristence of events
-
- Args:
- txn
- event (EventBase)
- """
- relation = event.content.get("m.relates_to")
- if not relation:
- # No relations
- return
-
- rel_type = relation.get("rel_type")
- if rel_type not in (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.REPLACE,
- ):
- # Unknown relation type
- return
-
- parent_id = relation.get("event_id")
- if not parent_id:
- # Invalid relation
- return
-
- aggregation_key = relation.get("key")
-
- self.db.simple_insert_txn(
- txn,
- table="event_relations",
- values={
- "event_id": event.event_id,
- "relates_to_id": parent_id,
- "relation_type": rel_type,
- "aggregation_key": aggregation_key,
- },
- )
-
- txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
- txn.call_after(
- self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
- )
-
- if rel_type == RelationTypes.REPLACE:
- txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
-
- def _handle_redaction(self, txn, redacted_event_id):
- """Handles receiving a redaction and checking whether we need to remove
- any redacted relations from the database.
-
- Args:
- txn
- redacted_event_id (str): The event that was redacted.
- """
-
- self.db.simple_delete_txn(
- txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
- )
+ pass
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index e6c10c6316..46f643c6b9 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -21,8 +21,6 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
-from six import integer_types
-
from canonicaljson import json
from twisted.internet import defer
@@ -52,12 +50,28 @@ class RoomSortOrder(Enum):
"""
Enum to define the sorting method used when returning rooms with get_rooms_paginate
- ALPHABETICAL = sort rooms alphabetically by name
- SIZE = sort rooms by membership size, highest to lowest
+ NAME = sort rooms alphabetically by name
+ JOINED_MEMBERS = sort rooms by membership size, highest to lowest
"""
+ # ALPHABETICAL and SIZE are deprecated.
+ # ALPHABETICAL is the same as NAME.
ALPHABETICAL = "alphabetical"
+ # SIZE is the same as JOINED_MEMBERS.
SIZE = "size"
+ NAME = "name"
+ CANONICAL_ALIAS = "canonical_alias"
+ JOINED_MEMBERS = "joined_members"
+ JOINED_LOCAL_MEMBERS = "joined_local_members"
+ VERSION = "version"
+ CREATOR = "creator"
+ ENCRYPTION = "encryption"
+ FEDERATABLE = "federatable"
+ PUBLIC = "public"
+ JOIN_RULES = "join_rules"
+ GUEST_ACCESS = "guest_access"
+ HISTORY_VISIBILITY = "history_visibility"
+ STATE_EVENTS = "state_events"
class RoomWorkerStore(SQLBaseStore):
@@ -82,6 +96,37 @@ class RoomWorkerStore(SQLBaseStore):
allow_none=True,
)
+ def get_room_with_stats(self, room_id: str):
+ """Retrieve room with statistics.
+
+ Args:
+ room_id: The ID of the room to retrieve.
+ Returns:
+ A dict containing the room information, or None if the room is unknown.
+ """
+
+ def get_room_with_stats_txn(txn, room_id):
+ sql = """
+ SELECT room_id, state.name, state.canonical_alias, curr.joined_members,
+ curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
+ rooms.creator, state.encryption, state.is_federatable AS federatable,
+ rooms.is_public AS public, state.join_rules, state.guest_access,
+ state.history_visibility, curr.current_state_events AS state_events
+ FROM rooms
+ LEFT JOIN room_stats_state state USING (room_id)
+ LEFT JOIN room_stats_current curr USING (room_id)
+ WHERE room_id = ?
+ """
+ txn.execute(sql, [room_id])
+ res = self.db.cursor_to_dict(txn)
+ if not res:
+ # the room is unknown: return None, as documented above
+ return None
+ res = res[0]
+ res["federatable"] = bool(res["federatable"])
+ res["public"] = bool(res["public"])
+ return res
+
+ return self.db.runInteraction(
+ "get_room_with_stats", get_room_with_stats_txn, room_id
+ )
+
def get_public_room_ids(self):
return self.db.simple_select_onecol(
table="rooms",
@@ -329,12 +374,52 @@ class RoomWorkerStore(SQLBaseStore):
# Set ordering
if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
+ # Deprecated in favour of RoomSortOrder.JOINED_MEMBERS
order_by_column = "curr.joined_members"
order_by_asc = False
elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL:
- # Sort alphabetically
+ # Deprecated in favour of RoomSortOrder.NAME
+ order_by_column = "state.name"
+ order_by_asc = True
+ elif RoomSortOrder(order_by) == RoomSortOrder.NAME:
order_by_column = "state.name"
order_by_asc = True
+ elif RoomSortOrder(order_by) == RoomSortOrder.CANONICAL_ALIAS:
+ order_by_column = "state.canonical_alias"
+ order_by_asc = True
+ elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_MEMBERS:
+ order_by_column = "curr.joined_members"
+ order_by_asc = False
+ elif RoomSortOrder(order_by) == RoomSortOrder.JOINED_LOCAL_MEMBERS:
+ order_by_column = "curr.local_users_in_room"
+ order_by_asc = False
+ elif RoomSortOrder(order_by) == RoomSortOrder.VERSION:
+ order_by_column = "rooms.room_version"
+ order_by_asc = False
+ elif RoomSortOrder(order_by) == RoomSortOrder.CREATOR:
+ order_by_column = "rooms.creator"
+ order_by_asc = True
+ elif RoomSortOrder(order_by) == RoomSortOrder.ENCRYPTION:
+ order_by_column = "state.encryption"
+ order_by_asc = True
+ elif RoomSortOrder(order_by) == RoomSortOrder.FEDERATABLE:
+ order_by_column = "state.is_federatable"
+ order_by_asc = True
+ elif RoomSortOrder(order_by) == RoomSortOrder.PUBLIC:
+ order_by_column = "rooms.is_public"
+ order_by_asc = True
+ elif RoomSortOrder(order_by) == RoomSortOrder.JOIN_RULES:
+ order_by_column = "state.join_rules"
+ order_by_asc = True
+ elif RoomSortOrder(order_by) == RoomSortOrder.GUEST_ACCESS:
+ order_by_column = "state.guest_access"
+ order_by_asc = True
+ elif RoomSortOrder(order_by) == RoomSortOrder.HISTORY_VISIBILITY:
+ order_by_column = "state.history_visibility"
+ order_by_asc = True
+ elif RoomSortOrder(order_by) == RoomSortOrder.STATE_EVENTS:
+ order_by_column = "curr.current_state_events"
+ order_by_asc = False
else:
raise StoreError(
500, "Incorrect value for order_by provided: %s" % order_by
@@ -349,9 +434,13 @@ class RoomWorkerStore(SQLBaseStore):
# for, and another query for getting the total number of events that could be
# returned. Thus allowing us to see if there are more events to paginate through
info_sql = """
- SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members
+ SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members,
+ curr.local_users_in_room, rooms.room_version, rooms.creator,
+ state.encryption, state.is_federatable, rooms.is_public, state.join_rules,
+ state.guest_access, state.history_visibility, curr.current_state_events
FROM room_stats_state state
INNER JOIN room_stats_current curr USING (room_id)
+ INNER JOIN rooms USING (room_id)
%s
ORDER BY %s %s
LIMIT ?
@@ -389,6 +478,16 @@ class RoomWorkerStore(SQLBaseStore):
"name": room[1],
"canonical_alias": room[2],
"joined_members": room[3],
+ "joined_local_members": room[4],
+ "version": room[5],
+ "creator": room[6],
+ "encryption": room[7],
+ "federatable": room[8],
+ "public": room[9],
+ "join_rules": room[10],
+ "guest_access": room[11],
+ "history_visibility": room[12],
+ "state_events": room[13],
}
)
@@ -732,6 +831,26 @@ class RoomWorkerStore(SQLBaseStore):
return total_media_quarantined
+ def get_all_new_public_rooms(self, prev_id, current_id, limit):
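+ """Get the public room list changes with stream IDs in the range (prev_id, current_id], up to limit rows."""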
+ def get_all_new_public_rooms(txn):
+ sql = """
+ SELECT stream_id, room_id, visibility, appservice_id, network_id
+ FROM public_room_list_stream
+ WHERE stream_id > ? AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (prev_id, current_id, limit))
+ return txn.fetchall()
+
+ if prev_id == current_id:
+ return defer.succeed([])
+
+ return self.db.runInteraction(
+ "get_all_new_public_rooms", get_all_new_public_rooms
+ )
+
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1181,53 +1300,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return self.db.runInteraction("get_rooms", f)
- def _store_room_topic_txn(self, txn, event):
- if hasattr(event, "content") and "topic" in event.content:
- self.store_event_search_txn(
- txn, event, "content.topic", event.content["topic"]
- )
-
- def _store_room_name_txn(self, txn, event):
- if hasattr(event, "content") and "name" in event.content:
- self.store_event_search_txn(
- txn, event, "content.name", event.content["name"]
- )
-
- def _store_room_message_txn(self, txn, event):
- if hasattr(event, "content") and "body" in event.content:
- self.store_event_search_txn(
- txn, event, "content.body", event.content["body"]
- )
-
- def _store_retention_policy_for_room_txn(self, txn, event):
- if hasattr(event, "content") and (
- "min_lifetime" in event.content or "max_lifetime" in event.content
- ):
- if (
- "min_lifetime" in event.content
- and not isinstance(event.content.get("min_lifetime"), integer_types)
- ) or (
- "max_lifetime" in event.content
- and not isinstance(event.content.get("max_lifetime"), integer_types)
- ):
- # Ignore the event if one of the value isn't an integer.
- return
-
- self.db.simple_insert_txn(
- txn=txn,
- table="room_retention",
- values={
- "room_id": event.room_id,
- "event_id": event.event_id,
- "min_lifetime": event.content.get("min_lifetime"),
- "max_lifetime": event.content.get("max_lifetime"),
- },
- )
-
- self._invalidate_cache_and_stream(
- txn, self.get_retention_policy_for_room, (event.room_id,)
- )
-
def add_event_report(
self, room_id, event_id, user_id, reason, content, received_ts
):
@@ -1249,26 +1321,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- def get_all_new_public_rooms(self, prev_id, current_id, limit):
- def get_all_new_public_rooms(txn):
- sql = """
- SELECT stream_id, room_id, visibility, appservice_id, network_id
- FROM public_room_list_stream
- WHERE stream_id > ? AND stream_id <= ?
- ORDER BY stream_id ASC
- LIMIT ?
- """
-
- txn.execute(sql, (prev_id, current_id, limit))
- return txn.fetchall()
-
- if prev_id == current_id:
- return defer.succeed([])
-
- return self.db.runInteraction(
- "get_all_new_public_rooms", get_all_new_public_rooms
- )
-
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
"""Marks the room as blocked. Can be called multiple times.
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index d5bd0cb5cf..137ebac833 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -45,7 +45,6 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.metrics import Measure
-from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@@ -153,16 +152,6 @@ class RoomMemberWorkerStore(EventsWorkerStore):
self._check_safe_current_state_events_membership_updated_txn,
)
- @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
- def get_hosts_in_room(self, room_id, cache_context):
- """Returns the set of all hosts currently in the room
- """
- user_ids = yield self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
- hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
- return hosts
-
@cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id):
return self.db.runInteraction(
@@ -189,7 +178,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (room_id, Membership.JOIN))
- return [to_ascii(r[0]) for r in txn]
+ return [r[0] for r in txn]
@cached(max_entries=100000)
def get_room_summary(self, room_id):
@@ -233,7 +222,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(sql, (room_id,))
res = {}
for count, membership in txn:
- summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
+ summary = res.setdefault(membership, MemberSummary([], count))
# we order by membership and then fairly arbitrarily by event_id so
# heroes are consistent
@@ -265,11 +254,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
for user_id, membership, event_id in txn:
- summary = res[to_ascii(membership)]
+ summary = res[membership]
# we will always have a summary for this membership type at this
# point given the summary currently contains the counts.
members = summary.members
- members.append((to_ascii(user_id), to_ascii(event_id)))
+ members.append((user_id, event_id))
return res
@@ -576,7 +565,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if key[0] == EventTypes.Member
]
for etype, state_key in context.delta_ids:
- users_in_room.pop(state_key, None)
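+ # Only membership events can change the set of users in the room.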
+ if etype == EventTypes.Member:
+ users_in_room.pop(state_key, None)
# We check if we have any of the member event ids in the event cache
# before we ask the DB
@@ -593,13 +583,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
ev_entry = event_map.get(event_id)
if ev_entry:
if ev_entry.event.membership == Membership.JOIN:
- users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
- display_name=to_ascii(
- ev_entry.event.content.get("displayname", None)
- ),
- avatar_url=to_ascii(
- ev_entry.event.content.get("avatar_url", None)
- ),
+ users_in_room[ev_entry.event.state_key] = ProfileInfo(
+ display_name=ev_entry.event.content.get("displayname", None),
+ avatar_url=ev_entry.event.content.get("avatar_url", None),
)
else:
missing_member_event_ids.append(event_id)
@@ -613,9 +599,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
if event.event_id in member_event_ids:
- users_in_room[to_ascii(event.state_key)] = ProfileInfo(
- display_name=to_ascii(event.content.get("displayname", None)),
- avatar_url=to_ascii(event.content.get("avatar_url", None)),
+ users_in_room[event.state_key] = ProfileInfo(
+ display_name=event.content.get("displayname", None),
+ avatar_url=event.content.get("avatar_url", None),
)
return users_in_room
@@ -1060,119 +1046,6 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
super(RoomMemberStore, self).__init__(database, db_conn, hs)
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database.
- """
- self.db.simple_insert_many_txn(
- txn,
- table="room_memberships",
- values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
- }
- for event in events
- ],
- )
-
- for event in events:
- txn.call_after(
- self._membership_stream_cache.entity_has_changed,
- event.state_key,
- event.internal_metadata.stream_ordering,
- )
- txn.call_after(
- self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
- )
-
- # We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened. If the event is an
- # outlier it is only current if its an "out of band membership",
- # like a remote invite or a rejection of a remote invite.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_out_of_band_membership()
- )
- is_mine = self.hs.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self.db.simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- },
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(
- sql,
- (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ),
- )
-
- # We also update the `local_current_membership` table with
- # latest invite info. This will usually get updated by the
- # `current_state_events` handling, unless its an outlier.
- if event.internal_metadata.is_outlier():
- # This should only happen for out of band memberships, so
- # we add a paranoia check.
- assert event.internal_metadata.is_out_of_band_membership()
-
- self.db.simple_upsert_txn(
- txn,
- table="local_current_membership",
- keyvalues={
- "room_id": event.room_id,
- "user_id": event.state_key,
- },
- values={
- "event_id": event.event_id,
- "membership": event.membership,
- },
- )
-
- @defer.inlineCallbacks
- def locally_reject_invite(self, user_id, room_id):
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- def f(txn, stream_ordering):
- txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
- # We also clear this entry from `local_current_membership`.
- # Ideally we'd point to a leave event, but we don't have one, so
- # nevermind.
- self.db.simple_delete_txn(
- txn,
- table="local_current_membership",
- keyvalues={"room_id": room_id, "user_id": user_id},
- )
-
- with self._stream_id_gen.get_next() as stream_ordering:
- yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
-
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
index 163529c071..bbdde121e8 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
@@ -35,9 +35,13 @@ DELETE FROM background_updates WHERE update_name IN (
'populate_stats_cleanup'
);
+-- this relies on current_state_events.membership having been populated, so add
+-- a dependency on current_state_events_membership.
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
- ('populate_stats_process_rooms', '{}', '');
+ ('populate_stats_process_rooms', '{}', 'current_state_events_membership');
+-- this also relies on current_state_events.membership having been populated, but
+-- we get that as a side-effect of depending on populate_stats_process_rooms.
INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
('populate_stats_process_users', '{}', 'populate_stats_process_rooms');
diff --git a/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql
new file mode 100644
index 0000000000..133d80af35
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/remove_sent_outbound_pokes.sql
@@ -0,0 +1,21 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- we no longer keep sent outbound device pokes in the db; clear them out
+-- so that we don't have to worry about them.
+--
+-- This is a sequence scan, but it doesn't take too long.
+
+DELETE FROM device_lists_outbound_pokes WHERE sent;
diff --git a/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql
new file mode 100644
index 0000000000..fdc39e9ba5
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/02remove_dup_outbound_pokes.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ /* for some reason, we have accumulated duplicate entries in
+ * device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
+ * efficient.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json)
+ VALUES (5800, 'remove_dup_outbound_pokes', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql
new file mode 100644
index 0000000000..dcb593fc2d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql
@@ -0,0 +1,36 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS ui_auth_sessions(
+ session_id TEXT NOT NULL, -- The session ID passed to the client.
+ creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds).
+ serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse.
+ clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client.
+ uri TEXT NOT NULL, -- The URI the UI authentication session is using.
+ method TEXT NOT NULL, -- The HTTP method the UI authentication session is using.
+ -- The clientdict, uri, and method make up a tuple that must be immutable
+ -- throughout the lifetime of the UI Auth session.
+ description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur.
+ UNIQUE (session_id)
+);
+
+CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials(
+ session_id TEXT NOT NULL, -- The corresponding UI Auth session.
+ stage_type TEXT NOT NULL, -- The stage type.
+ result TEXT NOT NULL, -- The result of the stage verification, stored as JSON.
+ UNIQUE (session_id, stage_type),
+ FOREIGN KEY (session_id)
+ REFERENCES ui_auth_sessions (session_id)
+);
diff --git a/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
new file mode 100644
index 0000000000..aa46eb0e10
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/05cache_instance.sql.postgres
@@ -0,0 +1,30 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We keep the old table here to enable us to roll back. It doesn't matter
+-- that we have dropped all the data here.
+TRUNCATE cache_invalidation_stream;
+
+CREATE TABLE cache_invalidation_stream_by_instance (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ cache_func TEXT NOT NULL,
+ keys TEXT[],
+ invalidation_ts BIGINT
+);
+
+CREATE UNIQUE INDEX cache_invalidation_stream_by_instance_id ON cache_invalidation_stream_by_instance(stream_id);
+
+CREATE SEQUENCE cache_invalidation_stream_seq;
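+
+-- every instance allocates stream IDs for cache_invalidation_stream_by_instance
+-- from this shared sequence.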
diff --git a/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py
new file mode 100644
index 0000000000..d353f2bcb3
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/06dlols_unique_idx.py
@@ -0,0 +1,80 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This migration rebuilds the device_lists_outbound_last_success table without duplicate
+entries, and with a UNIQUE index.
+"""
+
+import logging
+from io import StringIO
+
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.prepare_database import execute_statements_from_stream
+from synapse.storage.types import Cursor
+
+logger = logging.getLogger(__name__)
+
+
+def run_upgrade(*args, **kwargs):
+ pass
+
+
+def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs):
+ # some instances might already have this index, in which case we can skip this
+ if isinstance(database_engine, PostgresEngine):
+ cur.execute(
+ """
+ SELECT 1 FROM pg_class WHERE relkind = 'i'
+ AND relname = 'device_lists_outbound_last_success_unique_idx'
+ """
+ )
+
+ if cur.rowcount:
+ logger.info(
+ "Unique index exists on device_lists_outbound_last_success: "
+ "skipping rebuild"
+ )
+ return
+
+ logger.info("Rebuilding device_lists_outbound_last_success with unique index")
+ execute_statements_from_stream(cur, StringIO(_rebuild_commands))
+
+
+# there might be duplicates, so the easiest way to achieve this is to create a new
+# table with the right data, and rename it into place
+
+_rebuild_commands = """
+DROP TABLE IF EXISTS device_lists_outbound_last_success_new;
+
+CREATE TABLE device_lists_outbound_last_success_new (
+ destination TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ stream_id BIGINT NOT NULL
+);
+
+-- this took about 30 seconds on matrix.org's 16 million rows.
+INSERT INTO device_lists_outbound_last_success_new
+ SELECT destination, user_id, MAX(stream_id) FROM device_lists_outbound_last_success
+ GROUP BY destination, user_id;
+
+-- and this another 30 seconds.
+CREATE UNIQUE INDEX device_lists_outbound_last_success_unique_idx
+ ON device_lists_outbound_last_success_new (destination, user_id);
+
+DROP TABLE device_lists_outbound_last_success;
+
+ALTER TABLE device_lists_outbound_last_success_new
+ RENAME TO device_lists_outbound_last_success;
+"""
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 47ebb8a214..13f49d8060 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -37,7 +37,55 @@ SearchEntry = namedtuple(
)
-class SearchBackgroundUpdateStore(SQLBaseStore):
+class SearchWorkerStore(SQLBaseStore):
+ def store_search_entries_txn(self, txn, entries):
+ """Add entries to the search table
+
+ Args:
+ txn (cursor):
+ entries (iterable[SearchEntry]):
+ entries to be added to the table
+ """
+ if not self.hs.config.enable_search:
+ return
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = (
+ "INSERT INTO event_search"
+ " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+ " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+ )
+
+ args = (
+ (
+ entry.event_id,
+ entry.room_id,
+ entry.key,
+ entry.value,
+ entry.stream_ordering,
+ entry.origin_server_ts,
+ )
+ for entry in entries
+ )
+
+ txn.executemany(sql, args)
+
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = (
+ "INSERT INTO event_search (event_id, room_id, key, value)"
+ " VALUES (?,?,?,?)"
+ )
+ args = (
+ (entry.event_id, entry.room_id, entry.key, entry.value)
+ for entry in entries
+ )
+
+ txn.executemany(sql, args)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
+
+class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@@ -296,80 +344,11 @@ class SearchBackgroundUpdateStore(SQLBaseStore):
return num_rows
- def store_search_entries_txn(self, txn, entries):
- """Add entries to the search table
-
- Args:
- txn (cursor):
- entries (iterable[SearchEntry]):
- entries to be added to the table
- """
- if not self.hs.config.enable_search:
- return
- if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search"
- " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
- " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
- )
-
- args = (
- (
- entry.event_id,
- entry.room_id,
- entry.key,
- entry.value,
- entry.stream_ordering,
- entry.origin_server_ts,
- )
- for entry in entries
- )
-
- txn.executemany(sql, args)
-
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- args = (
- (entry.event_id, entry.room_id, entry.key, entry.value)
- for entry in entries
- )
-
- txn.executemany(sql, args)
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
-
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
super(SearchStore, self).__init__(database, db_conn, hs)
- def store_event_search_txn(self, txn, event, key, value):
- """Add event to the search table
-
- Args:
- txn (cursor):
- event (EventBase):
- key (str):
- value (str):
- """
- self.store_search_entries_txn(
- txn,
- (
- SearchEntry(
- key=key,
- value=value,
- event_id=event.event_id,
- room_id=event.room_id,
- stream_ordering=event.internal_metadata.stream_ordering,
- origin_server_ts=event.origin_server_ts,
- ),
- ),
- )
-
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py
index 563216b63c..36244d9f5d 100644
--- a/synapse/storage/data_stores/main/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -13,23 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import six
-
from unpaddedbase64 import encode_base64
from twisted.internet import defer
-from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
-
class SignatureWorkerStore(SQLBaseStore):
@cached()
@@ -79,23 +69,3 @@ class SignatureWorkerStore(SQLBaseStore):
class SignatureStore(SignatureWorkerStore):
"""Persistence for event signatures and hashes"""
-
- def _store_event_reference_hashes_txn(self, txn, events):
- """Store a hash for a PDU
- Args:
- txn (cursor):
- events (list): list of Events.
- """
-
- vals = []
- for event in events:
- ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
- vals.append(
- {
- "event_id": event.event_id,
- "algorithm": ref_alg,
- "hash": db_binary_type(ref_hash_bytes),
- }
- )
-
- self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 3a3b9a8e72..347cc50778 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -16,17 +16,12 @@
import collections.abc
import logging
from collections import namedtuple
-from typing import Iterable, Tuple
-
-from six import iteritems
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
-from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
@@ -34,7 +29,6 @@ from synapse.storage.database import Database
from synapse.storage.state import StateFilter
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.stringutils import to_ascii
logger = logging.getLogger(__name__)
@@ -190,9 +184,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
(room_id,),
)
- return {
- (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
- }
+ return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
return self.db.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
@@ -473,33 +465,3 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
def __init__(self, database: Database, db_conn, hs):
super(StateStore, self).__init__(database, db_conn, hs)
-
- def _store_event_state_mappings_txn(
- self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
- ):
- state_groups = {}
- for event, context in events_and_contexts:
- if event.internal_metadata.is_outlier():
- continue
-
- # if the event was rejected, just give it the same state as its
- # predecessor.
- if context.rejected:
- state_groups[event.event_id] = context.state_group_before_event
- continue
-
- state_groups[event.event_id] = context.state_group
-
- self.db.simple_insert_many_txn(
- txn,
- table="event_to_state_groups",
- values=[
- {"state_group": state_group_id, "event_id": event_id}
- for event_id, state_group_id in iteritems(state_groups)
- ],
- )
-
- for event_id, state_group_id in iteritems(state_groups):
- txn.call_after(
- self._get_state_group_for_event.prefill, (event_id,), state_group_id
- )
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index ada5cce6c2..e89f0bffb5 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -481,11 +481,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id, limit, end_token
)
- logger.debug("stream before")
events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
- logger.debug("stream after")
self._set_before_and_after(events, rows)
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 2aa1bafd48..4219018302 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -233,6 +233,9 @@ class TagsStore(TagsWorkerStore):
self._account_data_stream_cache.entity_has_changed, user_id, next_id
)
+ # Note: This is only here for backwards compat to allow admins to
+ # roll back to a previous Synapse version. Next time we update the
+ # database version we can remove this table.
update_max_id_sql = (
"UPDATE account_data_max_stream_id"
" SET stream_id = ?"
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py
index 5b07c2fbc0..a9bf457939 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -16,8 +16,6 @@
import logging
from collections import namedtuple
-import six
-
from canonicaljson import encode_canonical_json
from twisted.internet import defer
@@ -27,12 +25,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
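+# with Python 2 support gone, memoryview is the only binary column type we need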
+db_binary_type = memoryview
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py
new file mode 100644
index 0000000000..1d8ee22fb1
--- /dev/null
+++ b/synapse/storage/data_stores/main/ui_auth.py
@@ -0,0 +1,300 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+from typing import Any, Dict, Optional, Union
+
+import attr
+
+import synapse.util.stringutils as stringutils
+from synapse.api.errors import StoreError
+from synapse.storage._base import SQLBaseStore
+from synapse.types import JsonDict
+
+
+@attr.s
+class UIAuthSessionData:
+ session_id = attr.ib(type=str)
+ # The dictionary from the client root level, not the 'auth' key.
+ clientdict = attr.ib(type=JsonDict)
+ # The URI and method the session was initiated with. These are checked at
+ # each stage of the authentication to ensure that the asked for operation
+ # has not changed.
+ uri = attr.ib(type=str)
+ method = attr.ib(type=str)
+ # A string description of the operation that the current authentication is
+ # authorising.
+ description = attr.ib(type=str)
+
+
+class UIAuthWorkerStore(SQLBaseStore):
+ """
+ Manage user interactive authentication sessions.
+ """
+
+ async def create_ui_auth_session(
+ self, clientdict: JsonDict, uri: str, method: str, description: str,
+ ) -> UIAuthSessionData:
+ """
+ Creates a new user interactive authentication session.
+
+ The session can be used to track the stages necessary to authenticate a
+ user across multiple HTTP requests.
+
+ Args:
+ clientdict:
+ The dictionary from the client root level, not the 'auth' key.
+ uri:
+ The URI this session was initiated with, this is checked at each
+ stage of the authentication to ensure that the asked for
+ operation has not changed.
+ method:
+ The method this session was initiated with, this is checked at each
+ stage of the authentication to ensure that the asked for
+ operation has not changed.
+ description:
+ A string description of the operation that the current
+ authentication is authorising.
+ Returns:
+ The newly created session.
+ Raises:
+ StoreError if a unique session ID cannot be generated.
+ """
+ # The clientdict gets stored as JSON.
+ clientdict_json = json.dumps(clientdict)
+
+ # autogen a session ID and try to create it. We may clash, so just
+ # try a few times till one goes through, giving up eventually.
+ attempts = 0
+ while attempts < 5:
+ session_id = stringutils.random_string(24)
+
+ try:
+ await self.db.simple_insert(
+ table="ui_auth_sessions",
+ values={
+ "session_id": session_id,
+ "clientdict": clientdict_json,
+ "uri": uri,
+ "method": method,
+ "description": description,
+ "serverdict": "{}",
+ "creation_time": self.hs.get_clock().time_msec(),
+ },
+ desc="create_ui_auth_session",
+ )
+ return UIAuthSessionData(
+ session_id, clientdict, uri, method, description
+ )
+ except self.db.engine.module.IntegrityError:
+ attempts += 1
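+ # if we get here, every attempt to insert collided with an existing session ID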
+ raise StoreError(500, "Couldn't generate a session ID.")
+
+ async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
+ """Retrieve a UI auth session.
+
+ Args:
+ session_id: The ID of the session.
+ Returns:
+ The UIAuthSessionData for the session.
+ Raises:
+ StoreError if the session is not found.
+ """
+ result = await self.db.simple_select_one(
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcols=("clientdict", "uri", "method", "description"),
+ desc="get_ui_auth_session",
+ )
+
+ result["clientdict"] = json.loads(result["clientdict"])
+
+ return UIAuthSessionData(session_id, **result)
+
+ async def mark_ui_auth_stage_complete(
+ self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
+ ):
+ """
+ Mark a session stage as completed.
+
+ Args:
+ session_id: The ID of the corresponding session.
+ stage_type: The completed stage type.
+ result: The result of the stage verification.
+ Raises:
+ StoreError if the session cannot be found.
+ """
+ # Add (or update) the results of the current stage to the database.
+ #
+ # Note that we need to allow for the same stage to complete multiple
+ # times here so that registration is idempotent.
+ try:
+ await self.db.simple_upsert(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id, "stage_type": stage_type},
+ values={"result": json.dumps(result)},
+ desc="mark_ui_auth_stage_complete",
+ )
+ except self.db.engine.module.IntegrityError:
+ raise StoreError(400, "Unknown session ID: %s" % (session_id,))
+
+ async def get_completed_ui_auth_stages(
+ self, session_id: str
+ ) -> Dict[str, Union[str, bool, JsonDict]]:
+ """
+ Retrieve the completed stages of a UI authentication session.
+
+ Args:
+ session_id: The ID of the session.
+ Returns:
+ The completed stages mapped to the result of the verification of
+ that auth-type.
+ """
+ results = {}
+ for row in await self.db.simple_select_list(
+ table="ui_auth_sessions_credentials",
+ keyvalues={"session_id": session_id},
+ retcols=("stage_type", "result"),
+ desc="get_completed_ui_auth_stages",
+ ):
+ results[row["stage_type"]] = json.loads(row["result"])
+
+ return results
+
+ async def set_ui_auth_clientdict(
+ self, session_id: str, clientdict: JsonDict
+ ) -> None:
+ """
+ Store an updated clientdict for a given session ID.
+
+ Args:
+ session_id: The ID of this session as returned from check_auth
+ clientdict:
+ The dictionary from the client root level, not the 'auth' key.
+ """
+ # The clientdict gets stored as JSON.
+ clientdict_json = json.dumps(clientdict)
+
+ await self.db.simple_update_one(
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ updatevalues={"clientdict": clientdict_json},
+ desc="set_ui_auth_client_dict",
+ )
+
+ async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
+ """
+ Store a key-value pair into the session's data associated with this
+ request. This data is stored server-side and cannot be modified by
+ the client.
+
+ Args:
+ session_id: The ID of this session as returned from check_auth
+ key: The key to store the data under
+ value: The data to store
+ Raises:
+ StoreError if the session cannot be found.
+ """
+ await self.db.runInteraction(
+ "set_ui_auth_session_data",
+ self._set_ui_auth_session_data_txn,
+ session_id,
+ key,
+ value,
+ )
+
+ def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+ # Get the current value.
+ result = self.db.simple_select_one_txn(
+ txn,
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcols=("serverdict",),
+ )
+
+ # Update it and add it back to the database.
+ serverdict = json.loads(result["serverdict"])
+ serverdict[key] = value
+
+ self.db.simple_update_one_txn(
+ txn,
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ updatevalues={"serverdict": json.dumps(serverdict)},
+ )
+
+ async def get_ui_auth_session_data(
+ self, session_id: str, key: str, default: Optional[Any] = None
+ ) -> Any:
+ """
+ Retrieve data stored with set_ui_auth_session_data
+
+ Args:
+ session_id: The ID of this session as returned from check_auth
+ key: The key the data was stored under
+ default: Value to return if the key has not been set
+ Raises:
+ StoreError if the session cannot be found.
+ """
+ result = await self.db.simple_select_one(
+ table="ui_auth_sessions",
+ keyvalues={"session_id": session_id},
+ retcols=("serverdict",),
+ desc="get_ui_auth_session_data",
+ )
+
+ serverdict = json.loads(result["serverdict"])
+
+ return serverdict.get(key, default)
+
+
+class UIAuthStore(UIAuthWorkerStore):
+ def delete_old_ui_auth_sessions(self, expiration_time: int):
+ """
+ Remove sessions which were last used earlier than the expiration time.
+
+ Args:
+ expiration_time: The latest time that is still considered valid.
+ This is an epoch time in milliseconds.
+
+ """
+ return self.db.runInteraction(
+ "delete_old_ui_auth_sessions",
+ self._delete_old_ui_auth_sessions_txn,
+ expiration_time,
+ )
+
+ def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+ # Get the expired sessions.
+ sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
+ txn.execute(sql, [expiration_time])
+ session_ids = [r[0] for r in txn.fetchall()]
+
+ # Delete the corresponding completed credentials.
+ self.db.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions_credentials",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
+
+ # Finally, delete the sessions.
+ self.db.simple_delete_many_txn(
+ txn,
+ table="ui_auth_sessions",
+ column="session_id",
+ iterable=session_ids,
+ keyvalues={},
+ )
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py
index e8edaf9f7b..ff000bc9ec 100644
--- a/synapse/storage/data_stores/state/bg_updates.py
+++ b/synapse/storage/data_stores/state/bg_updates.py
@@ -109,20 +109,20 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
- SELECT DISTINCT type, state_key, last_value(event_id) OVER (
- PARTITION BY type, state_key ORDER BY state_group ASC
- ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
- ) AS event_id FROM state_groups_state
+ SELECT DISTINCT ON (type, state_key)
+ type, state_key, event_id
+ FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
- )
+ ) %s
+ ORDER BY type, state_key, state_group DESC
"""
for group in groups:
args = [group]
args.extend(where_args)
- txn.execute(sql + where_clause, args)
+ txn.execute(sql % (where_clause,), args)
for row in txn:
typ, state_key, event_id = row
key = (typ, state_key)
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index 57a5267663..f3ad1e4369 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -28,7 +28,6 @@ from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateSt
from synapse.storage.database import Database
from synapse.storage.state import StateFilter
from synapse.types import StateMap
-from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -90,11 +89,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self._state_group_cache = DictionaryCache(
"*stateGroupCache*",
# TODO: this hasn't been tuned yet
- 50000 * get_cache_factor_for("stateGroupCache"),
+ 50000,
)
self._state_group_members_cache = DictionaryCache(
- "*stateGroupMembersCache*",
- 500000 * get_cache_factor_for("stateGroupMembersCache"),
+ "*stateGroupMembersCache*", 500000,
)
@cached(max_entries=10000, iterable=True)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index e61595336c..b112ff3df2 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -17,7 +17,17 @@
import logging
import time
from time import monotonic as monotonic_time
-from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+)
from six import iteritems, iterkeys, itervalues
from six.moves import intern, range
@@ -32,13 +42,14 @@ from synapse.config.database import DatabaseConnectionConfig
from synapse.logging.context import (
LoggingContext,
LoggingContextOrSentinel,
+ current_context,
make_deferred_yieldable,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
-from synapse.util.stringutils import exception_to_unicode
+from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -201,9 +212,9 @@ class LoggingTransaction:
def executemany(self, sql: str, *args: Any):
self._do_execute(self.txn.executemany, sql, *args)
- def _make_sql_one_line(self, sql):
+ def _make_sql_one_line(self, sql: str) -> str:
"Strip newlines out of SQL so that the loggers in the DB are on one line"
- return " ".join(l.strip() for l in sql.splitlines() if l.strip())
+ return " ".join(line.strip() for line in sql.splitlines() if line.strip())
def _do_execute(self, func, sql, *args):
sql = self._make_sql_one_line(sql)
@@ -411,20 +422,14 @@ class Database(object):
# This can happen if the database disappears mid
# transaction.
logger.warning(
- "[TXN OPERROR] {%s} %s %d/%d",
- name,
- exception_to_unicode(e),
- i,
- N,
+ "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
)
if i < N:
i += 1
try:
conn.rollback()
except self.engine.module.Error as e1:
- logger.warning(
- "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
- )
+ logger.warning("[TXN EROLL] {%s} %s", name, e1)
continue
raise
except self.engine.module.DatabaseError as e:
@@ -436,9 +441,7 @@ class Database(object):
conn.rollback()
except self.engine.module.Error as e1:
logger.warning(
- "[TXN EROLL] {%s} %s",
- name,
- exception_to_unicode(e1),
+ "[TXN EROLL] {%s} %s", name, e1,
)
continue
raise
@@ -483,7 +486,7 @@ class Database(object):
end = monotonic_time()
duration = end - start
- LoggingContext.current_context().add_database_transaction(duration)
+ current_context().add_database_transaction(duration)
transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
@@ -510,7 +513,7 @@ class Database(object):
after_callbacks = [] # type: List[_CallbackListEntry]
exception_callbacks = [] # type: List[_CallbackListEntry]
- if LoggingContext.current_context() == LoggingContext.sentinel:
+ if not current_context():
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
@@ -547,10 +550,8 @@ class Database(object):
Returns:
Deferred: The result of func
"""
- parent_context = (
- LoggingContext.current_context()
- ) # type: Optional[LoggingContextOrSentinel]
- if parent_context == LoggingContext.sentinel:
+ parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
+ if not parent_context:
logger.warning(
"Starting db connection from sentinel context: metrics will be lost"
)
@@ -880,20 +881,24 @@ class Database(object):
txn.execute(sql, list(allvalues.values()))
def simple_upsert_many_txn(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[str]],
+ ) -> None:
"""
Upsert, many times.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_many_txn_native_upsert(
@@ -905,20 +910,24 @@ class Database(object):
)
def simple_upsert_many_txn_emulated(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Iterable[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[str]],
+ ) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
# No value columns, therefore make a blank list so that the following
# zip() works correctly.
@@ -932,20 +941,24 @@ class Database(object):
self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
def simple_upsert_many_txn_native_upsert(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
+ self,
+ txn: LoggingTransaction,
+ table: str,
+ key_names: Collection[str],
+ key_values: Collection[Iterable[Any]],
+ value_names: Collection[str],
+ value_values: Iterable[Iterable[Any]],
+ ) -> None:
"""
Upsert, many times, using batching where possible.
Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
+ table: The table to upsert into
+ key_names: The key column names.
+ key_values: A list of each row's key column values.
+ value_names: The value column names
+ value_values: A list of each row's value column values.
+ Ignored if value_names is empty.
"""
allnames = [] # type: List[str]
allnames.extend(key_names)
@@ -1558,3 +1571,74 @@ def make_in_list_sql_clause(
return "%s = ANY(?)" % (column,), [list(iterable)]
else:
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
+
+
+KV = TypeVar("KV")
+
+
+def make_tuple_comparison_clause(
+ database_engine: BaseDatabaseEngine, keys: List[Tuple[str, KV]]
+) -> Tuple[str, List[KV]]:
+ """Returns a tuple comparison SQL clause
+
+    Depending on what the SQL engine supports, builds a SQL clause that behaves
+    like either "(a, b) > (?, ?)" or "(a > ?) OR (a == ? AND b > ?)".
+
+ Args:
+ database_engine
+        keys: An ordered list of (column, value) pairs to be compared.
+
+ Returns:
+ A tuple of SQL query and the args
+ """
+ if database_engine.supports_tuple_comparison:
+ return (
+ "(%s) > (%s)" % (",".join(k[0] for k in keys), ",".join("?" for _ in keys)),
+ [k[1] for k in keys],
+ )
+
+ # we want to build a clause
+ # (a > ?) OR
+ # (a == ? AND b > ?) OR
+ # (a == ? AND b == ? AND c > ?)
+ # ...
+ # (a == ? AND b == ? AND ... AND z > ?)
+ #
+ # or, equivalently:
+ #
+ # (a > ? OR (a == ? AND
+ # (b > ? OR (b == ? AND
+ # ...
+ # (y > ? OR (y == ? AND
+ # z > ?
+ # ))
+ # ...
+ # ))
+ # ))
+ #
+ # which itself is equivalent to (and apparently easier for the query optimiser):
+ #
+ # (a >= ? AND (a > ? OR
+ # (b >= ? AND (b > ? OR
+ # ...
+ # (y >= ? AND (y > ? OR
+ # z > ?
+ # ))
+ # ...
+ # ))
+ # ))
+ #
+ #
+
+ clause = ""
+ args = [] # type: List[KV]
+ for k, v in keys[:-1]:
+ clause = clause + "(%s >= ? AND (%s > ? OR " % (k, k)
+ args.extend([v, v])
+
+ (k, v) = keys[-1]
+ clause += "%s > ?" % (k,)
+ args.append(v)
+
+ clause += "))" * (len(keys) - 1)
+ return clause, args
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 3bc2e8b986..215a949442 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -85,6 +85,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank)
+ db_conn.execute("PRAGMA foreign_keys = ON;")
def is_deadlock(self, error):
return False
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 0f9ac1cf09..f159400a87 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -23,7 +23,6 @@ from typing import Iterable, List, Optional, Set, Tuple
from six import iteritems
from six.moves import range
-import attr
from prometheus_client import Counter, Histogram
from twisted.internet import defer
@@ -35,6 +34,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore
from synapse.storage.data_stores import DataStores
+from synapse.storage.data_stores.main.events import DeltaState
from synapse.types import StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -73,22 +73,6 @@ stale_forward_extremities_counter = Histogram(
)
-@attr.s(slots=True)
-class DeltaState:
- """Deltas to use to update the `current_state_events` table.
-
- Attributes:
- to_delete: List of type/state_keys to delete from current state
- to_insert: Map of state to upsert into current state
- no_longer_in_room: The server is not longer in the room, so the room
- should e.g. be removed from `current_state_events` table.
- """
-
- to_delete = attr.ib(type=List[Tuple[str, str]])
- to_insert = attr.ib(type=StateMap[str])
- no_longer_in_room = attr.ib(type=bool, default=False)
-
-
class _EventPeristenceQueue(object):
"""Queues up events so that they can be persisted in bulk with only one
concurrent transaction per room.
@@ -205,6 +189,7 @@ class EventsPersistenceStorage(object):
# store for now.
self.main_store = stores.main
self.state_store = stores.state
+ self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
@@ -445,7 +430,7 @@ class EventsPersistenceStorage(object):
if current_state is not None:
current_state_for_room[room_id] = current_state
- await self.main_store._persist_events_and_state_updates(
+ await self.persist_events_store._persist_events_and_state_updates(
chunk,
current_state_for_room=current_state_for_room,
state_delta_for_room=state_delta_for_room,
@@ -491,13 +476,15 @@ class EventsPersistenceStorage(object):
)
# Remove any events which are prev_events of any existing events.
- existing_prevs = await self.main_store._get_events_which_are_prevs(result)
+ existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
+ result
+ )
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev
# events. If they do we need to remove them and their prev events,
# otherwise we end up with dangling extremities.
- existing_prevs = await self.main_store._get_prevs_before_rejected(
+ existing_prevs = await self.persist_events_store._get_prevs_before_rejected(
e_id for event in new_events for e_id in event.prev_event_ids()
)
result.difference_update(existing_prevs)
@@ -753,8 +740,8 @@ class EventsPersistenceStorage(object):
# whose state has changed as we've already their new state above.
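+        # Only membership events carry a user ID in their state_key, so check
+        # the event type before testing is_mine_id.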
users_to_ignore = [
state_key
- for _, state_key in itertools.chain(delta.to_insert, delta.to_delete)
- if self.is_mine_id(state_key)
+ for typ, state_key in itertools.chain(delta.to_insert, delta.to_delete)
+ if typ == EventTypes.Member and self.is_mine_id(state_key)
]
if await self.main_store.is_local_host_in_room_ignoring_users(
@@ -799,3 +786,9 @@ class EventsPersistenceStorage(object):
for user_id in left_users:
await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
+
+ async def locally_reject_invite(self, user_id: str, room_id: str) -> int:
+ """Mark the invite has having been rejected even though we failed to
+ create a leave event for it.
+ """
+ return await self.persist_events_store.locally_reject_invite(user_id, room_id)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 6cb7d4b922..9cc3b51fe6 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -19,17 +19,22 @@ import logging
import os
import re
from collections import Counter
+from typing import TextIO
import attr
from synapse.storage.engines.postgres import PostgresEngine
+from synapse.storage.types import Cursor
logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 57
+# XXX: If you're about to bump this to 59 (or higher) please create an update
+# that drops the unused `cache_invalidation_stream` table, as per #7436!
+# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
+SCHEMA_VERSION = 58
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -362,9 +367,8 @@ def _upgrade_existing_database(
if duplicates:
# We don't support using the same file name in the same delta version.
raise PrepareDatabaseException(
- "Found multiple delta files with the same name in v%d: %s",
- v,
- duplicates,
+ "Found multiple delta files with the same name in v%d: %s"
+ % (v, duplicates,)
)
# We sort to ensure that we apply the delta files in a consistent
@@ -477,8 +481,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
)
logger.info("applying schema %s for %s", name, modname)
- for statement in get_statements(stream):
- cur.execute(statement)
+ execute_statements_from_stream(cur, stream)
# Mark as done.
cur.execute(
@@ -536,8 +539,12 @@ def get_statements(f):
def executescript(txn, schema_path):
with open(schema_path, "r") as f:
- for statement in get_statements(f):
- txn.execute(statement)
+ execute_statements_from_stream(txn, f)
+
+
+def execute_statements_from_stream(cur: Cursor, f: TextIO):
+ for statement in get_statements(f):
+ cur.execute(statement)
def _get_or_create_schema_state(txn, database_engine):
diff --git a/synapse/storage/schema/delta/58/00background_update_ordering.sql b/synapse/storage/schema/delta/58/00background_update_ordering.sql
new file mode 100644
index 0000000000..02dae587cc
--- /dev/null
+++ b/synapse/storage/schema/delta/58/00background_update_ordering.sql
@@ -0,0 +1,19 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* add an "ordering" column to background_updates, which can be used to sort them
+ to achieve some level of consistency. */
+
+ALTER TABLE background_updates ADD COLUMN ordering INT NOT NULL DEFAULT 0;
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9d851beaa5..f89ce0bed2 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -16,6 +16,11 @@
import contextlib
import threading
from collections import deque
+from typing import Dict, Optional, Set, Tuple
+
+from typing_extensions import Deque
+
+from synapse.storage.database import Database, LoggingTransaction
class IdGenerator(object):
@@ -87,7 +92,7 @@ class StreamIdGenerator(object):
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[int]
def get_next(self):
"""
@@ -161,9 +166,10 @@ class ChainedIdGenerator(object):
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
+ self._table = table
self._lock = threading.Lock()
self._current_max = _load_current_id(db_conn, table, column)
- self._unfinished_ids = deque()
+ self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
def get_next(self):
"""
@@ -198,3 +204,173 @@ class ChainedIdGenerator(object):
return stream_id - 1, chained_id
return self._current_max, self.chained_generator.get_current_token()
+
+ def advance(self, token: int):
+ """Stub implementation for advancing the token when receiving updates
+ over replication; raises an exception as this instance should be the
+ only source of updates.
+ """
+
+        raise Exception(
+            "Attempted to advance token on source for table %r" % (self._table,)
+        )
+
+
+class MultiWriterIdGenerator:
+ """An ID generator that tracks a stream that can have multiple writers.
+
+ Uses a Postgres sequence to coordinate ID assignment, but positions of other
+ writers will only get updated when `advance` is called (by replication).
+
+ Note: Only works with Postgres.
+
+ Args:
+ db_conn
+ db
+ instance_name: The name of this instance.
+ table: Database table associated with stream.
+ instance_column: Column that stores the row's writer's instance name
+ id_column: Column that stores the stream ID.
+ sequence_name: The name of the postgres sequence used to generate new
+ IDs.
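+
+    Usage sketch (the instance, table, column and sequence names below are
+    purely illustrative):
+
+        gen = MultiWriterIdGenerator(
+            db_conn, db,
+            instance_name="worker1",
+            table="example_stream",
+            instance_column="instance_name",
+            id_column="stream_id",
+            sequence_name="example_stream_seq",
+        )
+
+    IDs are then allocated with `get_next`/`get_next_txn`, and `advance` is
+    called as other writers' positions arrive over replication.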
+ """
+
+ def __init__(
+ self,
+ db_conn,
+ db: Database,
+ instance_name: str,
+ table: str,
+ instance_column: str,
+ id_column: str,
+ sequence_name: str,
+ ):
+ self._db = db
+ self._instance_name = instance_name
+ self._sequence_name = sequence_name
+
+ # We lock as some functions may be called from DB threads.
+ self._lock = threading.Lock()
+
+ self._current_positions = self._load_current_ids(
+ db_conn, table, instance_column, id_column
+ )
+
+ # Set of local IDs that we're still processing. The current position
+ # should be less than the minimum of this set (if not empty).
+ self._unfinished_ids = set() # type: Set[int]
+
+ def _load_current_ids(
+ self, db_conn, table: str, instance_column: str, id_column: str
+ ) -> Dict[str, int]:
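+        # Fetch the highest stream ID written by each instance. The table and
+        # column names are trusted identifiers (not user input), so they are
+        # interpolated directly into the SQL rather than passed as parameters.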
+ sql = """
+ SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
+ GROUP BY %(instance)s
+ """ % {
+ "instance": instance_column,
+ "id": id_column,
+ "table": table,
+ }
+
+ cur = db_conn.cursor()
+ cur.execute(sql)
+
+ # `cur` is an iterable over returned rows, which are 2-tuples.
+ current_positions = dict(cur)
+
+ cur.close()
+
+ return current_positions
+
+ def _load_next_id_txn(self, txn):
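+        # Ask the Postgres sequence for a fresh ID; the sequence is the single
+        # source of new IDs shared by all writer instances.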
+ txn.execute("SELECT nextval(?)", (self._sequence_name,))
+ (next_id,) = txn.fetchone()
+ return next_id
+
+ async def get_next(self):
+ """
+ Usage:
+ with await stream_id_gen.get_next() as stream_id:
+ # ... persist event ...
+ """
+ next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
+
+ # Assert the fetched ID is actually greater than what we currently
+ # believe the ID to be. If not, then the sequence and table have got
+ # out of sync somehow.
+ assert self.get_current_token() < next_id
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_id
+ finally:
+ self._mark_id_as_finished(next_id)
+
+ return manager()
+
+ def get_next_txn(self, txn: LoggingTransaction):
+ """
+ Usage:
+
+            stream_id = stream_id_gen.get_next_txn(txn)
+ # ... persist event ...
+ """
+
+ next_id = self._load_next_id_txn(txn)
+
+ with self._lock:
+ self._unfinished_ids.add(next_id)
+
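+        # Mark the ID as finished whether the transaction commits or rolls
+        # back, so that the current position can continue to advance.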
+ txn.call_after(self._mark_id_as_finished, next_id)
+ txn.call_on_exception(self._mark_id_as_finished, next_id)
+
+ return next_id
+
+ def _mark_id_as_finished(self, next_id: int):
+ """The ID has finished being processed so we should advance the
+        current position if possible.
+ """
+
+ with self._lock:
+ self._unfinished_ids.discard(next_id)
+
+            # Figure out if it's safe to advance the position by checking there
+ # aren't any lower allocated IDs that are yet to finish.
+ if all(c > next_id for c in self._unfinished_ids):
+ curr = self._current_positions.get(self._instance_name, 0)
+ self._current_positions[self._instance_name] = max(curr, next_id)
+
+    def get_current_token(self, instance_name: Optional[str] = None) -> int:
+ """Gets the current position of a named writer (defaults to current
+ instance).
+
+ Returns 0 if we don't have a position for the named writer (likely due
+ to it being a new writer).
+ """
+
+ if instance_name is None:
+ instance_name = self._instance_name
+
+ with self._lock:
+ return self._current_positions.get(instance_name, 0)
+
+ def get_positions(self) -> Dict[str, int]:
+ """Get a copy of the current positon map.
+ """
+
+ with self._lock:
+ return dict(self._current_positions)
+
+ def advance(self, instance_name: str, new_id: int):
+ """Advance the postion of the named writer to the given ID, if greater
+ than existing entry.
+ """
+
+ with self._lock:
+ self._current_positions[instance_name] = max(
+ new_id, self._current_positions.get(instance_name, 0)
+ )