diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 71316f7d09..ec89f645d4 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018,2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,506 +14,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import calendar
-import logging
-import time
-
-from twisted.internet import defer
-
-from synapse.api.constants import PresenceState
-from synapse.storage.devices import DeviceStore
-from synapse.storage.user_erasure_store import UserErasureStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-from .account_data import AccountDataStore
-from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
-from .client_ips import ClientIpStore
-from .deviceinbox import DeviceInboxStore
-from .directory import DirectoryStore
-from .e2e_room_keys import EndToEndRoomKeyStore
-from .end_to_end_keys import EndToEndKeyStore
-from .engines import PostgresEngine
-from .event_federation import EventFederationStore
-from .event_push_actions import EventPushActionsStore
-from .events import EventsStore
-from .events_bg_updates import EventsBackgroundUpdatesStore
-from .filtering import FilteringStore
-from .group_server import GroupServerStore
-from .keys import KeyStore
-from .media_repository import MediaRepositoryStore
-from .monthly_active_users import MonthlyActiveUsersStore
-from .openid import OpenIdStore
-from .presence import PresenceStore, UserPresenceState
-from .profile import ProfileStore
-from .push_rule import PushRuleStore
-from .pusher import PusherStore
-from .receipts import ReceiptsStore
-from .registration import RegistrationStore
-from .rejections import RejectionsStore
-from .relations import RelationsStore
-from .room import RoomStore
-from .roommember import RoomMemberStore
-from .search import SearchStore
-from .signatures import SignatureStore
-from .state import StateStore
-from .stats import StatsStore
-from .stream import StreamStore
-from .tags import TagsStore
-from .transactions import TransactionStore
-from .user_directory import UserDirectoryStore
-from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator
-
-logger = logging.getLogger(__name__)
-
-
-class DataStore(
- EventsBackgroundUpdatesStore,
- RoomMemberStore,
- RoomStore,
- RegistrationStore,
- StreamStore,
- ProfileStore,
- PresenceStore,
- TransactionStore,
- DirectoryStore,
- KeyStore,
- StateStore,
- SignatureStore,
- ApplicationServiceStore,
- EventsStore,
- EventFederationStore,
- MediaRepositoryStore,
- RejectionsStore,
- FilteringStore,
- PusherStore,
- PushRuleStore,
- ApplicationServiceTransactionStore,
- ReceiptsStore,
- EndToEndKeyStore,
- EndToEndRoomKeyStore,
- SearchStore,
- TagsStore,
- AccountDataStore,
- EventPushActionsStore,
- OpenIdStore,
- ClientIpStore,
- DeviceStore,
- DeviceInboxStore,
- UserDirectoryStore,
- GroupServerStore,
- UserErasureStore,
- MonthlyActiveUsersStore,
- StatsStore,
- RelationsStore,
-):
- def __init__(self, db_conn, hs):
- self.hs = hs
- self._clock = hs.get_clock()
- self.database_engine = hs.database_engine
-
- self._stream_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- extra_tables=[("local_invites", "stream_id")],
- )
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
- )
- self._presence_id_gen = StreamIdGenerator(
- db_conn, "presence_stream", "stream_id"
- )
- self._device_inbox_id_gen = StreamIdGenerator(
- db_conn, "device_max_stream_id", "stream_id"
- )
- self._public_room_id_gen = StreamIdGenerator(
- db_conn, "public_room_list_stream", "stream_id"
- )
- self._device_list_id_gen = StreamIdGenerator(
- db_conn, "device_lists_stream", "stream_id"
- )
-
- self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
- self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
- self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
- self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._push_rules_stream_id_gen = ChainedIdGenerator(
- self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
- )
- self._pushers_id_gen = StreamIdGenerator(
- db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
- )
- self._group_updates_id_gen = StreamIdGenerator(
- db_conn, "local_group_updates", "stream_id"
- )
-
- if isinstance(self.database_engine, PostgresEngine):
- self._cache_id_gen = StreamIdGenerator(
- db_conn, "cache_invalidation_stream", "stream_id"
- )
- else:
- self._cache_id_gen = None
-
- self._presence_on_startup = self._get_active_presence(db_conn)
-
- presence_cache_prefill, min_presence_val = self._get_cache_dict(
- db_conn,
- "presence_stream",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self._presence_id_gen.get_current_token(),
- )
- self.presence_stream_cache = StreamChangeCache(
- "PresenceStreamChangeCache",
- min_presence_val,
- prefilled_cache=presence_cache_prefill,
- )
-
- max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
- device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
- db_conn,
- "device_inbox",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=max_device_inbox_id,
- limit=1000,
- )
- self._device_inbox_stream_cache = StreamChangeCache(
- "DeviceInboxStreamChangeCache",
- min_device_inbox_id,
- prefilled_cache=device_inbox_prefill,
- )
- # The federation outbox and the local device inbox uses the same
- # stream_id generator.
- device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
- db_conn,
- "device_federation_outbox",
- entity_column="destination",
- stream_column="stream_id",
- max_value=max_device_inbox_id,
- limit=1000,
- )
- self._device_federation_outbox_stream_cache = StreamChangeCache(
- "DeviceFederationOutboxStreamChangeCache",
- min_device_outbox_id,
- prefilled_cache=device_outbox_prefill,
- )
-
- device_list_max = self._device_list_id_gen.get_current_token()
- self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache", device_list_max
- )
- self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache", device_list_max
- )
-
- events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
- db_conn,
- "current_state_delta_stream",
- entity_column="room_id",
- stream_column="stream_id",
- max_value=events_max, # As we share the stream id with events token
- limit=1000,
- )
- self._curr_state_delta_stream_cache = StreamChangeCache(
- "_curr_state_delta_stream_cache",
- min_curr_state_delta_id,
- prefilled_cache=curr_state_delta_prefill,
- )
-
- _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
- db_conn,
- "local_group_updates",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self._group_updates_id_gen.get_current_token(),
- limit=1000,
- )
- self._group_updates_stream_cache = StreamChangeCache(
- "_group_updates_stream_cache",
- min_group_updates_id,
- prefilled_cache=_group_updates_prefill,
- )
-
- self._stream_order_on_start = self.get_room_max_stream_ordering()
- self._min_stream_order_on_start = self.get_room_min_stream_ordering()
-
- # Used in _generate_user_daily_visits to keep track of progress
- self._last_user_visit_update = self._get_start_of_day()
-
- super(DataStore, self).__init__(db_conn, hs)
-
- def take_presence_startup_info(self):
- active_on_startup = self._presence_on_startup
- self._presence_on_startup = None
- return active_on_startup
-
- def _get_active_presence(self, db_conn):
- """Fetch non-offline presence from the database so that we can register
- the appropriate time outs.
- """
-
- sql = (
- "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
- " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
- " WHERE state != ?"
- )
- sql = self.database_engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
- txn.execute(sql, (PresenceState.OFFLINE,))
- rows = self.cursor_to_dict(txn)
- txn.close()
-
- for row in rows:
- row["currently_active"] = bool(row["currently_active"])
-
- return [UserPresenceState(**row) for row in rows]
-
- def count_daily_users(self):
- """
- Counts the number of users who used this homeserver in the last 24 hours.
- """
-
- def _count_users(txn):
- yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT user_id FROM user_ips
- WHERE last_seen > ?
- GROUP BY user_id
- ) u
- """
-
- txn.execute(sql, (yesterday,))
- count, = txn.fetchone()
- return count
-
- return self.runInteraction("count_users", _count_users)
-
- def count_r30_users(self):
- """
- Counts the number of 30 day retained users, defined as:-
- * Users who have created their accounts more than 30 days ago
- * Where last seen at most 30 days ago
- * Where account creation and last_seen are > 30 days apart
-
- Returns counts globaly for a given user as well as breaking
- by platform
- """
-
- def _count_r30_users(txn):
- thirty_days_in_secs = 86400 * 30
- now = int(self._clock.time())
- thirty_days_ago_in_secs = now - thirty_days_in_secs
-
- sql = """
- SELECT platform, COALESCE(count(*), 0) FROM (
- SELECT
- users.name, platform, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen,
- CASE
- WHEN user_agent LIKE '%%Android%%' THEN 'android'
- WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
- WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
- WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
- WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
- ELSE 'unknown'
- END
- AS platform
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND users.appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, platform, users.creation_ts
- ) u GROUP BY platform
- """
-
- results = {}
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- for row in txn:
- if row[0] == 'unknown':
- pass
- results[row[0]] = row[1]
-
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT users.name, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, users.creation_ts
- ) u
- """
-
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- count, = txn.fetchone()
- results['all'] = count
-
- return results
-
- return self.runInteraction("count_r30_users", _count_r30_users)
-
- def _get_start_of_day(self):
- """
- Returns millisecond unixtime for start of UTC day.
- """
- now = time.gmtime()
- today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
- return today_start * 1000
-
- def generate_user_daily_visits(self):
- """
- Generates daily visit data for use in cohort/ retention analysis
- """
-
- def _generate_user_daily_visits(txn):
- logger.info("Calling _generate_user_daily_visits")
- today_start = self._get_start_of_day()
- a_day_in_milliseconds = 24 * 60 * 60 * 1000
- now = self.clock.time_msec()
-
- sql = """
- INSERT INTO user_daily_visits (user_id, device_id, timestamp)
- SELECT u.user_id, u.device_id, ?
- FROM user_ips AS u
- LEFT JOIN (
- SELECT user_id, device_id, timestamp FROM user_daily_visits
- WHERE timestamp = ?
- ) udv
- ON u.user_id = udv.user_id AND u.device_id=udv.device_id
- INNER JOIN users ON users.name=u.user_id
- WHERE last_seen > ? AND last_seen <= ?
- AND udv.timestamp IS NULL AND users.is_guest=0
- AND users.appservice_id IS NULL
- GROUP BY u.user_id, u.device_id
- """
-
- # This means that the day has rolled over but there could still
- # be entries from the previous day. There is an edge case
- # where if the user logs in at 23:59 and overwrites their
- # last_seen at 00:01 then they will not be counted in the
- # previous day's stats - it is important that the query is run
- # often to minimise this case.
- if today_start > self._last_user_visit_update:
- yesterday_start = today_start - a_day_in_milliseconds
- txn.execute(
- sql,
- (
- yesterday_start,
- yesterday_start,
- self._last_user_visit_update,
- today_start,
- ),
- )
- self._last_user_visit_update = today_start
-
- txn.execute(
- sql, (today_start, today_start, self._last_user_visit_update, now)
- )
- # Update _last_user_visit_update to now. The reason to do this
- # rather just clamping to the beginning of the day is to limit
- # the size of the join - meaning that the query can be run more
- # frequently
- self._last_user_visit_update = now
-
- return self.runInteraction(
- "generate_user_daily_visits", _generate_user_daily_visits
- )
-
- def get_users(self):
- """Function to reterive a list of users in users table.
-
- Args:
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self._simple_select_list(
- table="users",
- keyvalues={},
- retcols=["name", "password_hash", "is_guest", "admin"],
- desc="get_users",
- )
-
- @defer.inlineCallbacks
- def get_users_paginate(self, order, start, limit):
- """Function to reterive a paginated list of users from
- users list. This will return a json object, which contains
- list of users and the total number of users in users table.
-
- Args:
- order (str): column name to order the select by this column
- start (int): start number to begin the query from
- limit (int): number of rows to reterive
- Returns:
- defer.Deferred: resolves to json object {list[dict[str, Any]], count}
- """
- users = yield self.runInteraction(
- "get_users_paginate",
- self._simple_select_list_paginate_txn,
- table="users",
- keyvalues={"is_guest": False},
- orderby=order,
- start=start,
- limit=limit,
- retcols=["name", "password_hash", "is_guest", "admin"],
- )
- count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
- retval = {"users": users, "total": count}
- defer.returnValue(retval)
-
- def search_users(self, term):
- """Function to search users list for one or more users with
- the matched term.
-
- Args:
- term (str): search term
- col (str): column to query term should be matched to
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self._simple_search_list(
- table="users",
- term=term,
- col="name",
- retcols=["name", "password_hash", "is_guest", "admin"],
- desc="search_users",
- )
-
-
-def are_all_users_on_domain(txn, database_engine, domain):
- sql = database_engine.convert_param_style(
- "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
- )
- pat = "%:" + domain
- txn.execute(sql, (pat,))
- num_not_matching = txn.fetchall()[0][0]
- if num_not_matching == 0:
- return True
- return False
+"""
+The storage layer is split up into multiple parts to allow Synapse to run
+against different configurations of databases (e.g. single or multiple
+databases). The `Database` class represents a single physical database. The
+`data_stores` are classes that talk directly to a `Database` instance and have
+associated schemas, background updates, etc. On top of those there are classes
+that provide high-level interfaces combining calls to multiple `data_stores`.
+
+There are also schemas that get applied to every database, regardless of the
+data stores associated with them (e.g. the schema version tables), which are
+stored in `synapse.storage.schema`.
+"""
+
+from synapse.storage.data_stores import DataStores
+from synapse.storage.data_stores.main import DataStore
+from synapse.storage.persist_events import EventsPersistenceStorage
+from synapse.storage.purge_events import PurgeEventsStorage
+from synapse.storage.state import StateGroupStorage
+
+__all__ = ["DataStores", "DataStore"]
+
+
+class Storage(object):
+ """The high level interfaces for talking to various storage layers.
+ """
+
+ def __init__(self, hs, stores: DataStores):
+ # We include the main data store here mainly so that we don't have to
+ # rewrite all the existing code to split it into high vs low level
+ # interfaces.
+ self.main = stores.main
+
+ self.persistence = EventsPersistenceStorage(hs, stores)
+ self.purge_events = PurgeEventsStorage(hs, stores)
+ self.state = StateGroupStorage(hs, stores)
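
The module docstring and the new `Storage` class above describe a composition pattern: several high-level interfaces built over one shared set of low-level data stores, with `main` still exposed directly for existing callers. A minimal, self-contained sketch of that shape (the `Toy*` classes are illustrative stand-ins, not Synapse code):

class ToyStores:
    """Stand-in for `DataStores`: owns the low-level store objects."""

    def __init__(self):
        self.main = {"events": []}  # pretend "main" data store


class ToyPersistence:
    """Stand-in for a high-level interface such as `EventsPersistenceStorage`."""

    def __init__(self, stores):
        self._stores = stores

    def persist_event(self, event):
        # A real implementation coordinates several stores; here we just
        # append to the toy main store.
        self._stores.main["events"].append(event)


class ToyStorage:
    """Mirrors the shape of `Storage`: one facade, many interfaces."""

    def __init__(self, stores):
        self.main = stores.main  # direct access kept for legacy callers
        self.persistence = ToyPersistence(stores)


storage = ToyStorage(ToyStores())
storage.persistence.persist_event({"type": "m.room.message"})
assert storage.main["events"]  # both views share the same underlying data
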
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 537696547c..13de5f1f62 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,1382 +14,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
import random
-import sys
-import threading
-import time
+from abc import ABCMeta
+from typing import Any, Optional
-from six import PY2, iteritems, iterkeys, itervalues
-from six.moves import builtins, intern, range
+from six import PY2
+from six.moves import builtins
from canonicaljson import json
-from prometheus_client import Histogram
-from twisted.internet import defer
-
-from synapse.api.errors import StoreError
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import get_domain_from_id
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import Cache
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.stringutils import exception_to_unicode
+from synapse.storage.database import LoggingTransaction # noqa: F401
+from synapse.storage.database import make_in_list_sql_clause # noqa: F401
+from synapse.storage.database import Database
+from synapse.types import Collection, get_domain_from_id
logger = logging.getLogger(__name__)
-try:
- MAX_TXN_ID = sys.maxint - 1
-except AttributeError:
- # python 3 does not have a maximum int value
- MAX_TXN_ID = 2 ** 63 - 1
-
-sql_logger = logging.getLogger("synapse.storage.SQL")
-transaction_logger = logging.getLogger("synapse.storage.txn")
-perf_logger = logging.getLogger("synapse.storage.TIME")
-
-sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
-
-sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
-sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
-
-
-# Unique indexes which have been added in background updates. Maps from table name
-# to the name of the background update which added the unique index to that table.
-#
-# This is used by the upsert logic to figure out which tables are safe to do a proper
-# UPSERT on: until the relevant background update has completed, we
-# have to emulate an upsert by locking the table.
-#
-UNIQUE_INDEX_BACKGROUND_UPDATES = {
- "user_ips": "user_ips_device_unique_index",
- "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
- "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
- "event_search": "event_search_event_id_idx",
-}
-
-# This is a special cache name we use to batch multiple invalidations of caches
-# based on the current state when notifying workers over replication.
-_CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
-
-
-class LoggingTransaction(object):
- """An object that almost-transparently proxies for the 'txn' object
- passed to the constructor. Adds logging and metrics to the .execute()
- method."""
-
- __slots__ = [
- "txn",
- "name",
- "database_engine",
- "after_callbacks",
- "exception_callbacks",
- ]
-
- def __init__(
- self, txn, name, database_engine, after_callbacks, exception_callbacks
- ):
- object.__setattr__(self, "txn", txn)
- object.__setattr__(self, "name", name)
- object.__setattr__(self, "database_engine", database_engine)
- object.__setattr__(self, "after_callbacks", after_callbacks)
- object.__setattr__(self, "exception_callbacks", exception_callbacks)
-
- def call_after(self, callback, *args, **kwargs):
- """Call the given callback on the main twisted thread after the
- transaction has finished. Used to invalidate the caches on the
- correct thread.
- """
- self.after_callbacks.append((callback, args, kwargs))
-
- def call_on_exception(self, callback, *args, **kwargs):
- self.exception_callbacks.append((callback, args, kwargs))
-
- def __getattr__(self, name):
- return getattr(self.txn, name)
-
- def __setattr__(self, name, value):
- setattr(self.txn, name, value)
-
- def __iter__(self):
- return self.txn.__iter__()
-
- def execute_batch(self, sql, args):
- if isinstance(self.database_engine, PostgresEngine):
- from psycopg2.extras import execute_batch
-
- self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
- else:
- for val in args:
- self.execute(sql, val)
-
- def execute(self, sql, *args):
- self._do_execute(self.txn.execute, sql, *args)
- def executemany(self, sql, *args):
- self._do_execute(self.txn.executemany, sql, *args)
+# some of our subclasses have abstract methods, so we use the ABCMeta metaclass.
+class SQLBaseStore(metaclass=ABCMeta):
+ """Base class for data stores that holds helper functions.
- def _make_sql_one_line(self, sql):
- "Strip newlines out of SQL so that the loggers in the DB are on one line"
- return " ".join(l.strip() for l in sql.splitlines() if l.strip())
-
- def _do_execute(self, func, sql, *args):
- sql = self._make_sql_one_line(sql)
-
- # TODO(paul): Maybe use 'info' and 'debug' for values?
- sql_logger.debug("[SQL] {%s} %s", self.name, sql)
-
- sql = self.database_engine.convert_param_style(sql)
- if args:
- try:
- sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
- except Exception:
- # Don't let logging failures stop SQL from working
- pass
-
- start = time.time()
-
- try:
- return func(sql, *args)
- except Exception as e:
- logger.debug("[SQL FAIL] {%s} %s", self.name, e)
- raise
- finally:
- secs = time.time() - start
- sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
- sql_query_timer.labels(sql.split()[0]).observe(secs)
-
-
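
The removed `LoggingTransaction` (now imported from `synapse.storage.database`, per the new imports above) is an almost-transparent proxy: the wrapped cursor is stored with `object.__setattr__`, unknown attributes are forwarded via `__getattr__`, and `.execute()` is wrapped to add logging and timing. A standalone sketch of that proxy technique, using `sqlite3` purely for demonstration:

import sqlite3
import time


class TimingProxy(object):
    """Stand-in for LoggingTransaction: times execute(), delegates the rest."""

    __slots__ = ["wrapped", "name"]

    def __init__(self, wrapped, name):
        # object.__setattr__ is used so that our own __setattr__ (which
        # forwards to the wrapped object) is bypassed for the proxy's slots.
        object.__setattr__(self, "wrapped", wrapped)
        object.__setattr__(self, "name", name)

    def __getattr__(self, attr):
        # Only called for attributes not found normally, so everything we
        # have not overridden is fetched from the wrapped cursor.
        return getattr(self.wrapped, attr)

    def __setattr__(self, attr, value):
        setattr(self.wrapped, attr, value)

    def execute(self, sql, *args):
        start = time.time()
        try:
            return self.wrapped.execute(sql, *args)
        finally:
            print("[SQL time] {%s} %f sec" % (self.name, time.time() - start))


cur = TimingProxy(sqlite3.connect(":memory:").cursor(), "demo")
cur.execute("SELECT 1 + 1")
print(cur.fetchone())  # fetchone() is delegated via __getattr__ -> (2,)
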
-class PerformanceCounters(object):
- def __init__(self):
- self.current_counters = {}
- self.previous_counters = {}
-
- def update(self, key, start_time, end_time=None):
- if end_time is None:
- end_time = time.time()
- duration = end_time - start_time
- count, cum_time = self.current_counters.get(key, (0, 0))
- count += 1
- cum_time += duration
- self.current_counters[key] = (count, cum_time)
- return end_time
-
- def interval(self, interval_duration, limit=3):
- counters = []
- for name, (count, cum_time) in iteritems(self.current_counters):
- prev_count, prev_time = self.previous_counters.get(name, (0, 0))
- counters.append(
- ((cum_time - prev_time) / interval_duration, count - prev_count, name)
- )
-
- self.previous_counters = dict(self.current_counters)
-
- counters.sort(reverse=True)
-
- top_n_counters = ", ".join(
- "%s(%d): %.3f%%" % (name, count, 100 * ratio)
- for ratio, count, name in counters[:limit]
- )
-
- return top_n_counters
-
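
`PerformanceCounters` above keeps, per named operation, a running count and cumulative duration; `interval()` then reports the busiest operations as a percentage of the reporting interval (the real class also subtracts the previous interval's totals, which this rough sketch omits):

import time

counters = {}  # name -> (count, cumulative seconds), as in current_counters


def record(name, start_time):
    count, cum_time = counters.get(name, (0, 0.0))
    counters[name] = (count + 1, cum_time + (time.time() - start_time))


start = time.time()
sum(range(1000000))  # some work worth measuring
record("sum_range", start)

interval_duration = 10.0  # pretend 10 seconds elapsed since the last report
report = ", ".join(
    "%s(%d): %.3f%%" % (name, count, 100 * cum_time / interval_duration)
    for name, (count, cum_time) in counters.items()
)
print(report)  # e.g. "sum_range(1): 0.12%"
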
-
-class SQLBaseStore(object):
- _TXN_ID = 0
+ Note that multiple instances of this class will exist as there will be one
+ per data store (and not one per physical database).
+ """
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
- self._db_pool = hs.get_db_pool()
-
- self._previous_txn_total_time = 0
- self._current_txn_total_time = 0
- self._previous_loop_ts = 0
-
- # TODO(paul): These can eventually be removed once the metrics code
- # is running in mainline, and we have some nice monitoring frontends
- # to watch it
- self._txn_perf_counters = PerformanceCounters()
- self._get_event_counters = PerformanceCounters()
-
- self._get_event_cache = Cache(
- "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
- )
-
- self._event_fetch_lock = threading.Condition()
- self._event_fetch_list = []
- self._event_fetch_ongoing = 0
-
- self._pending_ds = []
-
- self.database_engine = hs.database_engine
-
- # A set of tables that are not safe to use native upserts in.
- self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
-
- self._account_validity = self.hs.config.account_validity
-
- # We add the user_directory_search table to the blacklist on SQLite
- # because the existing search table does not have an index, making it
- # unsafe to use native upserts.
- if isinstance(self.database_engine, Sqlite3Engine):
- self._unsafe_to_upsert_tables.add("user_directory_search")
-
- if self.database_engine.can_native_upsert:
- # Check ASAP (and then later, every 1s) to see if we have finished
- # background updates of tables that aren't safe to update.
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "upsert_safety_check",
- self._check_safe_to_upsert,
- )
-
+ self.database_engine = database.engine
+ self.db = database
self.rand = random.SystemRandom()
- if self._account_validity.enabled:
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "account_validity_set_expiration_dates",
- self._set_expiration_date_when_missing,
- )
-
- @defer.inlineCallbacks
- def _check_safe_to_upsert(self):
- """
- Is it safe to use native UPSERT?
-
- If there are background updates, we will need to wait, as they may be
- the addition of indexes that set the UNIQUE constraint that we require.
-
- If the background updates have not completed, wait 15 sec and check again.
- """
- updates = yield self._simple_select_list(
- "background_updates",
- keyvalues=None,
- retcols=["update_name"],
- desc="check_background_updates",
- )
- updates = [x["update_name"] for x in updates]
-
- for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
- if update_name not in updates:
- logger.debug("Now safe to upsert in %s", table)
- self._unsafe_to_upsert_tables.discard(table)
-
- # If there's any updates still running, reschedule to run.
- if updates:
- self._clock.call_later(
- 15.0,
- run_as_background_process,
- "upsert_safety_check",
- self._check_safe_to_upsert,
- )
-
- @defer.inlineCallbacks
- def _set_expiration_date_when_missing(self):
- """
- Retrieves the list of registered users that don't have an expiration date, and
- adds an expiration date for each of them.
- """
-
- def select_users_with_no_expiration_date_txn(txn):
- """Retrieves the list of registered users with no expiration date from the
- database, filtering out deactivated users.
- """
- sql = (
- "SELECT users.name FROM users"
- " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
- " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
- )
- txn.execute(sql, [])
-
- res = self.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn,
- user["name"],
- use_delta=True,
- )
-
- yield self.runInteraction(
- "get_users_with_no_expiration_date",
- select_users_with_no_expiration_date_txn,
- )
-
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
- """Sets an expiration date to the account with the given user ID.
-
- Args:
- user_id (str): User ID to set an expiration date for.
- use_delta (bool): If set to False, the expiration date for the user will be
- now + validity period. If set to True, this expiration date will be a
- random value in the [now + period - d ; now + period] range, d being a
- delta equal to 10% of the validity period.
- """
- now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
-
- if use_delta:
- expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
- expiration_ts,
- )
-
- self._simple_upsert_txn(
- txn,
- "account_validity",
- keyvalues={"user_id": user_id, },
- values={"expiration_ts_ms": expiration_ts, "email_sent": False, },
- )
-
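
`set_expiration_date_for_user_txn` above jitters expiration dates across the tail of the validity period when `use_delta` is set. A worked sketch of that arithmetic, assuming a 30-day period and the 10% delta described in the docstring:

import random

DAY_MS = 24 * 60 * 60 * 1000
period = 30 * DAY_MS            # assumed validity period
max_delta = period * 10 // 100  # 10% of the period, i.e. 3 days

now_ms = 1600000000000  # arbitrary "current time" in ms
expiration_ts = now_ms + period

# With use_delta, the stored expiry is uniform over [now + 27d, now + 30d).
jittered = random.randrange(expiration_ts - max_delta, expiration_ts)
assert expiration_ts - max_delta <= jittered < expiration_ts
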
- def start_profiling(self):
- self._previous_loop_ts = self._clock.time_msec()
-
- def loop():
- curr = self._current_txn_total_time
- prev = self._previous_txn_total_time
- self._previous_txn_total_time = curr
-
- time_now = self._clock.time_msec()
- time_then = self._previous_loop_ts
- self._previous_loop_ts = time_now
-
- ratio = (curr - prev) / (time_now - time_then)
-
- top_three_counters = self._txn_perf_counters.interval(
- time_now - time_then, limit=3
- )
-
- top_3_event_counters = self._get_event_counters.interval(
- time_now - time_then, limit=3
- )
-
- perf_logger.info(
- "Total database time: %.3f%% {%s} {%s}",
- ratio * 100,
- top_three_counters,
- top_3_event_counters,
- )
-
- self._clock.looping_call(loop, 10000)
-
- def _new_transaction(
- self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
- ):
- start = time.time()
- txn_id = self._TXN_ID
-
- # We don't really need these to be unique, so lets stop it from
- # growing really large.
- self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
-
- name = "%s-%x" % (desc, txn_id)
-
- transaction_logger.debug("[TXN START] {%s}", name)
-
- try:
- i = 0
- N = 5
- while True:
- try:
- txn = conn.cursor()
- txn = LoggingTransaction(
- txn,
- name,
- self.database_engine,
- after_callbacks,
- exception_callbacks,
- )
- r = func(txn, *args, **kwargs)
- conn.commit()
- return r
- except self.database_engine.module.OperationalError as e:
- # This can happen if the database disappears mid
- # transaction.
- logger.warning(
- "[TXN OPERROR] {%s} %s %d/%d",
- name,
- exception_to_unicode(e),
- i,
- N,
- )
- if i < N:
- i += 1
- try:
- conn.rollback()
- except self.database_engine.module.Error as e1:
- logger.warning(
- "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
- )
- continue
- raise
- except self.database_engine.module.DatabaseError as e:
- if self.database_engine.is_deadlock(e):
- logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
- if i < N:
- i += 1
- try:
- conn.rollback()
- except self.database_engine.module.Error as e1:
- logger.warning(
- "[TXN EROLL] {%s} %s",
- name,
- exception_to_unicode(e1),
- )
- continue
- raise
- except Exception as e:
- logger.debug("[TXN FAIL] {%s} %s", name, e)
- raise
- finally:
- end = time.time()
- duration = end - start
-
- LoggingContext.current_context().add_database_transaction(duration)
-
- transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
-
- self._current_txn_total_time += duration
- self._txn_perf_counters.update(desc, start, end)
- sql_txn_timer.labels(desc).observe(duration)
-
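
`_new_transaction` above retries the whole transaction up to five times when the database module reports an operational error or a deadlock, rolling back between attempts and re-raising once the retry budget is spent. A generic sketch of that bounded-retry shape (`RetryableError`, `FakeConnection` and `flaky` are stand-ins, not Synapse code):

class RetryableError(Exception):
    pass


class FakeConnection(object):
    def __init__(self):
        self.attempts = 0

    def rollback(self):
        pass  # a real connection would discard the failed transaction here


def run_transaction(conn, func, max_retries=5):
    i = 0
    while True:
        try:
            return func(conn)
        except RetryableError:
            if i < max_retries:
                i += 1
                conn.rollback()
                continue
            raise


def flaky(conn):
    conn.attempts += 1
    if conn.attempts < 3:
        raise RetryableError("database went away mid-transaction")
    return "ok"


assert run_transaction(FakeConnection(), flaky) == "ok"
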
- @defer.inlineCallbacks
- def runInteraction(self, desc, func, *args, **kwargs):
- """Starts a transaction on the database and runs a given function
-
- Arguments:
- desc (str): description of the transaction, for logging and metrics
- func (func): callback function, which will be called with a
- database transaction (twisted.enterprise.adbapi.Transaction) as
- its first argument, followed by `args` and `kwargs`.
-
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
-
- Returns:
- Deferred: The result of func
- """
- after_callbacks = []
- exception_callbacks = []
-
- if LoggingContext.current_context() == LoggingContext.sentinel:
- logger.warn("Starting db txn '%s' from sentinel context", desc)
-
- try:
- result = yield self.runWithConnection(
- self._new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- **kwargs
- )
-
- for after_callback, after_args, after_kwargs in after_callbacks:
- after_callback(*after_args, **after_kwargs)
- except: # noqa: E722, as we reraise the exception this is fine.
- for after_callback, after_args, after_kwargs in exception_callbacks:
- after_callback(*after_args, **after_kwargs)
- raise
-
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runWithConnection() method on the underlying db_pool.
-
- Arguments:
- func (func): callback function, which will be called with a
- database connection (twisted.enterprise.adbapi.Connection) as
- its first argument, followed by `args` and `kwargs`.
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
-
- Returns:
- Deferred: The result of func
- """
- parent_context = LoggingContext.current_context()
- if parent_context == LoggingContext.sentinel:
- logger.warn(
- "Starting db connection from sentinel context: metrics will be lost"
- )
- parent_context = None
-
- start_time = time.time()
-
- def inner_func(conn, *args, **kwargs):
- with LoggingContext("runWithConnection", parent_context) as context:
- sched_duration_sec = time.time() - start_time
- sql_scheduling_timer.observe(sched_duration_sec)
- context.add_database_scheduled(sched_duration_sec)
-
- if self.database_engine.is_connection_closed(conn):
- logger.debug("Reconnecting closed database connection")
- conn.reconnect()
-
- return func(conn, *args, **kwargs)
-
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
-
- defer.returnValue(result)
-
- @staticmethod
- def cursor_to_dict(cursor):
- """Converts a SQL cursor into an list of dicts.
-
- Args:
- cursor : The DBAPI cursor which has executed a query.
- Returns:
- A list of dicts where the key is the column header.
- """
- col_headers = list(intern(str(column[0])) for column in cursor.description)
- results = list(dict(zip(col_headers, row)) for row in cursor)
- return results
-
- def _execute(self, desc, decoder, query, *args):
- """Runs a single query for a result set.
-
- Args:
- decoder - The function which can resolve the cursor results to
- something meaningful.
- query - The query string to execute
- *args - Query args.
- Returns:
- The result of decoder(results)
- """
-
- def interaction(txn):
- txn.execute(query, args)
- if decoder:
- return decoder(txn)
- else:
- return txn.fetchall()
-
- return self.runInteraction(desc, interaction)
-
- # "Simple" SQL API methods that operate on a single table with no JOINs,
- # no complex WHERE clauses, just a dict of values for columns.
-
- @defer.inlineCallbacks
- def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
- """Executes an INSERT query on the named table.
-
- Args:
- table : string giving the table name
- values : dict of new column names and values for them
- or_ignore : bool stating whether an exception should be raised
- when a conflicting row already exists. If True, False will be
- returned by the function instead
- desc : string giving a description of the transaction
-
- Returns:
- bool: Whether the row was inserted or not. Only useful when
- `or_ignore` is True
- """
- try:
- yield self.runInteraction(desc, self._simple_insert_txn, table, values)
- except self.database_engine.module.IntegrityError:
- # We have to do or_ignore flag at this layer, since we can't reuse
- # a cursor after we receive an error from the db.
- if not or_ignore:
- raise
- defer.returnValue(False)
- defer.returnValue(True)
-
- @staticmethod
- def _simple_insert_txn(txn, table, values):
- keys, vals = zip(*values.items())
-
- sql = "INSERT INTO %s (%s) VALUES(%s)" % (
- table,
- ", ".join(k for k in keys),
- ", ".join("?" for _ in keys),
- )
-
- txn.execute(sql, vals)
-
- def _simple_insert_many(self, table, values, desc):
- return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
-
- @staticmethod
- def _simple_insert_many_txn(txn, table, values):
- if not values:
- return
-
- # This is a *slight* abomination to get a list of tuples of key names
- # and a list of tuples of value names.
- #
- # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
- # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
- #
- # The sort is to ensure that we don't rely on dictionary iteration
- # order.
- keys, vals = zip(
- *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
- )
-
- for k in keys:
- if k != keys[0]:
- raise RuntimeError("All items must have the same keys")
-
- sql = "INSERT INTO %s (%s) VALUES(%s)" % (
- table,
- ", ".join(k for k in keys[0]),
- ", ".join("?" for _ in keys[0]),
- )
-
- txn.executemany(sql, vals)
-
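
The "slight abomination" in `_simple_insert_many_txn` above turns a list of row dicts into one tuple of per-row key tuples and one tuple of per-row value tuples, sorted so that dict iteration order cannot matter. A tiny worked example (values illustrative) of what it produces:

rows = [{"b": 2, "a": 1}, {"a": 3, "b": 4}]

keys, vals = zip(
    *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in rows if i]
)

print(keys)  # (('a', 'b'), ('a', 'b'))  one key tuple per row, all identical
print(vals)  # ((1, 2), (3, 4))          matching value tuples for executemany()
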
- @defer.inlineCallbacks
- def _simple_upsert(
- self,
- table,
- keyvalues,
- values,
- insertion_values={},
- desc="_simple_upsert",
- lock=True,
- ):
- """
-
- `lock` should generally be set to True (the default), but can be set
- to False if either of the following are true:
-
- * there is a UNIQUE INDEX on the key columns. In this case a conflict
- will cause an IntegrityError in which case this function will retry
- the update.
-
- * we somehow know that we are the only thread which will be updating
- this table.
-
- Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key columns and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
- Returns:
- Deferred(None or bool): Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
- """
- attempts = 0
- while True:
- try:
- result = yield self.runInteraction(
- desc,
- self._simple_upsert_txn,
- table,
- keyvalues,
- values,
- insertion_values,
- lock=lock,
- )
- defer.returnValue(result)
- except self.database_engine.module.IntegrityError as e:
- attempts += 1
- if attempts >= 5:
- # don't retry forever, because things other than races
- # can cause IntegrityErrors
- raise
-
- # presumably we raced with another transaction: let's retry.
- logger.warn(
- "IntegrityError when upserting into %s; retrying: %s", table, e
- )
-
- def _simple_upsert_txn(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
- """
- Pick the UPSERT method which works best on the platform. Either the
- native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
-
- Args:
- txn: The transaction to use.
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
- Returns:
- None or bool: Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
- """
- if (
- self.database_engine.can_native_upsert
- and table not in self._unsafe_to_upsert_tables
- ):
- return self._simple_upsert_txn_native_upsert(
- txn, table, keyvalues, values, insertion_values=insertion_values
- )
- else:
- return self._simple_upsert_txn_emulated(
- txn,
- table,
- keyvalues,
- values,
- insertion_values=insertion_values,
- lock=lock,
- )
-
- def _simple_upsert_txn_emulated(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
- """
- Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
- Returns:
- bool: Return True if a new entry was created, False if an existing
- one was updated.
- """
- # We need to lock the table :(, unless we're *really* careful
- if lock:
- self.database_engine.lock_table(txn, table)
-
- def _getwhere(key):
- # If the value we're passing in is None (aka NULL), we need to use
- # IS, not =, as NULL = NULL equals NULL (False).
- if keyvalues[key] is None:
- return "%s IS ?" % (key,)
- else:
- return "%s = ?" % (key,)
-
- if not values:
- # If `values` is empty, then all of the values we care about are in
- # the unique key, so there is nothing to UPDATE. We can just do a
- # SELECT instead to see if it exists.
- sql = "SELECT 1 FROM %s WHERE %s" % (
- table,
- " AND ".join(_getwhere(k) for k in keyvalues),
- )
- sqlargs = list(keyvalues.values())
- txn.execute(sql, sqlargs)
- if txn.fetchall():
- # We have an existing record.
- return False
- else:
- # First try to update.
- sql = "UPDATE %s SET %s WHERE %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in values),
- " AND ".join(_getwhere(k) for k in keyvalues),
- )
- sqlargs = list(values.values()) + list(keyvalues.values())
-
- txn.execute(sql, sqlargs)
- if txn.rowcount > 0:
- # successfully updated at least one row.
- return False
-
- # We didn't find any existing rows, so insert a new one
- allvalues = {}
- allvalues.update(keyvalues)
- allvalues.update(values)
- allvalues.update(insertion_values)
-
- sql = "INSERT INTO %s (%s) VALUES (%s)" % (
- table,
- ", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues),
- )
- txn.execute(sql, list(allvalues.values()))
- # successfully inserted
- return True
-
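
`_simple_upsert_txn_emulated` above implements the fallback path: try an UPDATE, and only INSERT when no row matched. A runnable sketch of that flow against an in-memory SQLite table, borrowing the `account_validity` columns used earlier in this diff (the real code also locks the table unless the caller opts out, and handles empty `values` with a plain SELECT):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE account_validity (user_id TEXT, email_sent INTEGER)")


def upsert_emulated(conn, user_id, email_sent):
    cur = conn.execute(
        "UPDATE account_validity SET email_sent = ? WHERE user_id = ?",
        (email_sent, user_id),
    )
    if cur.rowcount > 0:
        return False  # updated an existing row
    conn.execute(
        "INSERT INTO account_validity (user_id, email_sent) VALUES (?, ?)",
        (user_id, email_sent),
    )
    return True  # inserted a new row


assert upsert_emulated(conn, "@alice:example.com", 0) is True   # first write inserts
assert upsert_emulated(conn, "@alice:example.com", 1) is False  # second write updates
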
- def _simple_upsert_txn_native_upsert(
- self, txn, table, keyvalues, values, insertion_values={}
- ):
- """
- Use the native UPSERT functionality in recent PostgreSQL versions.
-
- Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- Returns:
- None
- """
- allvalues = {}
- allvalues.update(keyvalues)
- allvalues.update(insertion_values)
-
- if not values:
- latter = "NOTHING"
- else:
- allvalues.update(values)
- latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
-
- sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
- table,
- ", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues),
- ", ".join(k for k in keyvalues),
- latter,
- )
- txn.execute(sql, list(allvalues.values()))
-
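
For comparison, `_simple_upsert_txn_native_upsert` above builds a single `INSERT ... ON CONFLICT ... DO UPDATE` statement instead. This sketch reproduces the string-building for the same illustrative `account_validity` columns so the generated SQL is visible:

keyvalues = {"user_id": "@alice:example.com"}
values = {"email_sent": False}
insertion_values = {"expiration_ts_ms": 0}

allvalues = {}
allvalues.update(keyvalues)
allvalues.update(insertion_values)
allvalues.update(values)

latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
    "account_validity",
    ", ".join(allvalues),
    ", ".join("?" for _ in allvalues),
    ", ".join(keyvalues),
    latter,
)
print(sql)
# INSERT INTO account_validity (user_id, expiration_ts_ms, email_sent)
#   VALUES (?, ?, ?) ON CONFLICT (user_id) DO UPDATE SET email_sent=EXCLUDED.email_sent
# (shown wrapped here; the statement is built as a single line)
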
- def _simple_upsert_many_txn(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
- """
- Upsert, many times.
-
- Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
- """
- if (
- self.database_engine.can_native_upsert
- and table not in self._unsafe_to_upsert_tables
- ):
- return self._simple_upsert_many_txn_native_upsert(
- txn, table, key_names, key_values, value_names, value_values
- )
- else:
- return self._simple_upsert_many_txn_emulated(
- txn, table, key_names, key_values, value_names, value_values
- )
-
- def _simple_upsert_many_txn_emulated(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
- """
- Upsert, many times, but without native UPSERT support or batching.
-
- Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
- """
- # No value columns, therefore make a blank list so that the following
- # zip() works correctly.
- if not value_names:
- value_values = [() for x in range(len(key_values))]
-
- for keyv, valv in zip(key_values, value_values):
- _keys = {x: y for x, y in zip(key_names, keyv)}
- _vals = {x: y for x, y in zip(value_names, valv)}
-
- self._simple_upsert_txn_emulated(txn, table, _keys, _vals)
-
- def _simple_upsert_many_txn_native_upsert(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
- """
- Upsert, many times, using batching where possible.
-
- Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
- """
- allnames = []
- allnames.extend(key_names)
- allnames.extend(value_names)
-
- if not value_names:
- # No value columns, therefore make a blank list so that the
- # following zip() works correctly.
- latter = "NOTHING"
- value_values = [() for x in range(len(key_values))]
- else:
- latter = "UPDATE SET " + ", ".join(
- k + "=EXCLUDED." + k for k in value_names
- )
-
- sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
- table,
- ", ".join(k for k in allnames),
- ", ".join("?" for _ in allnames),
- ", ".join(key_names),
- latter,
- )
-
- args = []
-
- for x, y in zip(key_values, value_values):
- args.append(tuple(x) + tuple(y))
-
- return txn.execute_batch(sql, args)
-
- def _simple_select_one(
- self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
- ):
- """Executes a SELECT query on the named table, which is expected to
- return a single row, returning multiple columns from it.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcols : list of strings giving the names of the columns to return
-
- allow_none : If true, return None instead of failing if the SELECT
- statement returns no rows
- """
- return self.runInteraction(
- desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
- )
-
- def _simple_select_one_onecol(
- self,
- table,
- keyvalues,
- retcol,
- allow_none=False,
- desc="_simple_select_one_onecol",
- ):
- """Executes a SELECT query on the named table, which is expected to
- return a single row, returning a single column from it.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcol : string giving the name of the column to return
- """
- return self.runInteraction(
- desc,
- self._simple_select_one_onecol_txn,
- table,
- keyvalues,
- retcol,
- allow_none=allow_none,
- )
-
- @classmethod
- def _simple_select_one_onecol_txn(
- cls, txn, table, keyvalues, retcol, allow_none=False
- ):
- ret = cls._simple_select_onecol_txn(
- txn, table=table, keyvalues=keyvalues, retcol=retcol
- )
-
- if ret:
- return ret[0]
- else:
- if allow_none:
- return None
- else:
- raise StoreError(404, "No row found")
-
- @staticmethod
- def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
- sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
-
- if keyvalues:
- sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
- txn.execute(sql, list(keyvalues.values()))
- else:
- txn.execute(sql)
-
- return [r[0] for r in txn]
-
- def _simple_select_onecol(
- self, table, keyvalues, retcol, desc="_simple_select_onecol"
- ):
- """Executes a SELECT query on the named table, which returns a list
- comprising of the values of the named column from the selected rows.
-
- Args:
- table (str): table name
- keyvalues (dict|None): column names and values to select the rows with
- retcol (str): column whos value we wish to retrieve.
-
- Returns:
- Deferred: Results in a list
- """
- return self.runInteraction(
- desc, self._simple_select_onecol_txn, table, keyvalues, retcol
- )
-
- def _simple_select_list(
- self, table, keyvalues, retcols, desc="_simple_select_list"
- ):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- table (str): the table name
- keyvalues (dict[str, Any] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self.runInteraction(
- desc, self._simple_select_list_txn, table, keyvalues, retcols
- )
-
- @classmethod
- def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- txn : Transaction object
- table (str): the table name
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
- """
- if keyvalues:
- sql = "SELECT %s FROM %s WHERE %s" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
- txn.execute(sql, list(keyvalues.values()))
- else:
- sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
- txn.execute(sql)
-
- return cls.cursor_to_dict(txn)
-
- @defer.inlineCallbacks
- def _simple_select_many_batch(
- self,
- table,
- column,
- iterable,
- retcols,
- keyvalues={},
- desc="_simple_select_many_batch",
- batch_size=100,
- ):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Filters rows by if value of `column` is in `iterable`.
-
- Args:
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
- """
- results = []
-
- if not iterable:
- defer.returnValue(results)
-
- # iterables can not be sliced, so convert it to a list first
- it_list = list(iterable)
-
- chunks = [
- it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
- ]
- for chunk in chunks:
- rows = yield self.runInteraction(
- desc,
- self._simple_select_many_txn,
- table,
- column,
- chunk,
- keyvalues,
- retcols,
- )
-
- results.extend(rows)
-
- defer.returnValue(results)
-
- @classmethod
- def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Filters rows by if value of `column` is in `iterable`.
-
- Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
- """
- if not iterable:
- return []
-
- sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
-
- clauses = []
- values = []
- clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
- values.extend(iterable)
-
- for key, value in iteritems(keyvalues):
- clauses.append("%s = ?" % (key,))
- values.append(value)
-
- if clauses:
- sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
-
- txn.execute(sql, values)
- return cls.cursor_to_dict(txn)
-
- def _simple_update(self, table, keyvalues, updatevalues, desc):
- return self.runInteraction(
- desc, self._simple_update_txn, table, keyvalues, updatevalues
- )
-
- @staticmethod
- def _simple_update_txn(txn, table, keyvalues, updatevalues):
- if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
- else:
- where = ""
-
- update_sql = "UPDATE %s SET %s %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in updatevalues),
- where,
- )
-
- txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
-
- return txn.rowcount
-
- def _simple_update_one(
- self, table, keyvalues, updatevalues, desc="_simple_update_one"
- ):
- """Executes an UPDATE query on the named table, setting new values for
- columns in a row matching the key values.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- updatevalues : dict giving column names and values to update
- retcols : optional list of column names to return
-
- If present, retcols gives a list of column names on which to perform
- a SELECT statement *before* performing the UPDATE statement. The values
- of these will be returned in a dict.
-
- These are performed within the same transaction, allowing an atomic
- get-and-set. This can be used to implement compare-and-set by putting
- the update column in the 'keyvalues' dict as well.
- """
- return self.runInteraction(
- desc, self._simple_update_one_txn, table, keyvalues, updatevalues
- )
-
- @classmethod
- def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
- rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
-
- if rowcount == 0:
- raise StoreError(404, "No row found (%s)" % (table,))
- if rowcount > 1:
- raise StoreError(500, "More than one row matched (%s)" % (table,))
-
- @staticmethod
- def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
- select_sql = "SELECT %s FROM %s WHERE %s" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
-
- txn.execute(select_sql, list(keyvalues.values()))
- row = txn.fetchone()
-
- if not row:
- if allow_none:
- return None
- raise StoreError(404, "No row found (%s)" % (table,))
- if txn.rowcount > 1:
- raise StoreError(500, "More than one row matched (%s)" % (table,))
-
- return dict(zip(retcols, row))
-
- def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
- """Executes a DELETE query on the named table, expecting to delete a
- single row.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- """
- return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
-
- @staticmethod
- def _simple_delete_one_txn(txn, table, keyvalues):
- """Executes a DELETE query on the named table, expecting to delete a
- single row.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- """
- sql = "DELETE FROM %s WHERE %s" % (
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
-
- txn.execute(sql, list(keyvalues.values()))
- if txn.rowcount == 0:
- raise StoreError(404, "No row found (%s)" % (table,))
- if txn.rowcount > 1:
- raise StoreError(500, "More than one row matched (%s)" % (table,))
-
- def _simple_delete(self, table, keyvalues, desc):
- return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
-
- @staticmethod
- def _simple_delete_txn(txn, table, keyvalues):
- sql = "DELETE FROM %s WHERE %s" % (
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
-
- txn.execute(sql, list(keyvalues.values()))
- return txn.rowcount
-
- def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
- return self.runInteraction(
- desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
- )
-
- @staticmethod
- def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
- """Executes a DELETE query on the named table.
-
- Filters rows by if value of `column` is in `iterable`.
-
- Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
-
- Returns:
- int: Number rows deleted
- """
- if not iterable:
- return 0
-
- sql = "DELETE FROM %s" % table
-
- clauses = []
- values = []
- clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
- values.extend(iterable)
-
- for key, value in iteritems(keyvalues):
- clauses.append("%s = ?" % (key,))
- values.append(value)
-
- if clauses:
- sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
- txn.execute(sql, values)
-
- return txn.rowcount
-
- def _get_cache_dict(
- self, db_conn, table, entity_column, stream_column, max_value, limit=100000
- ):
- # Fetch a mapping of room_id -> max stream position for "recent" rooms.
- # It doesn't really matter how many we get, the StreamChangeCache will
- # do the right thing to ensure it respects the max size of cache.
- sql = (
- "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
- " WHERE %(stream)s > ? - %(limit)s"
- " GROUP BY %(entity)s"
- ) % {
- "table": table,
- "entity": entity_column,
- "stream": stream_column,
- "limit": limit,
- }
-
- sql = self.database_engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
- txn.execute(sql, (int(max_value),))
-
- cache = {row[0]: int(row[1]) for row in txn}
-
- txn.close()
-
- if cache:
- min_val = min(itervalues(cache))
- else:
- min_val = max_value
-
- return cache, min_val
-
- def _invalidate_cache_and_stream(self, txn, cache_func, keys):
- """Invalidates the cache and adds it to the cache stream so slaves
- will know to invalidate their caches.
-
- This should only be used to invalidate caches where slaves won't
- otherwise know from other replication streams that the cache should
- be invalidated.
- """
- txn.call_after(cache_func.invalidate, keys)
- self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
-
- def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
- """Special case invalidation of caches based on current state.
-
- We special case this so that we can batch the cache invalidations into a
- single replication poke.
-
- Args:
- txn
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have changed
- """
- txn.call_after(self._invalidate_state_caches, room_id, members_changed)
-
- # We need to be careful that the size of the `members_changed` list
- # isn't so large that it causes problems sending over replication, so we
- # send them in chunks.
- # Max line length is 16K, and max user ID length is 255, so 50 should
- # be safe.
- for chunk in batch_iter(members_changed, 50):
- keys = itertools.chain([room_id], chunk)
- self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys)
-
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
@@ -1399,7 +56,7 @@ class SQLBaseStore(object):
members_changed (iterable[str]): The user_ids of members that have
changed
"""
- for host in set(get_domain_from_id(u) for u in members_changed):
+ for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
@@ -1407,242 +64,29 @@ class SQLBaseStore(object):
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
- def _attempt_to_invalidate_cache(self, cache_name, key):
+ def _attempt_to_invalidate_cache(
+ self, cache_name: str, key: Optional[Collection[Any]]
+ ):
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
Args:
- cache_name (str)
- key (tuple)
+ cache_name
+ key: Entry to invalidate. If None then invalidates the entire
+ cache.
"""
+
try:
- getattr(self, cache_name).invalidate(key)
+ if key is None:
+ getattr(self, cache_name).invalidate_all()
+ else:
+ getattr(self, cache_name).invalidate(tuple(key))
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
pass
- def _send_invalidation_to_replication(self, txn, cache_name, keys):
- """Notifies replication that given cache has been invalidated.
-
- Note that this does *not* invalidate the cache locally.
-
- Args:
- txn
- cache_name (str)
- keys (iterable[str])
- """
-
- if isinstance(self.database_engine, PostgresEngine):
- # get_next() returns a context manager which is designed to wrap
- # the transaction. However, we want to only get an ID when we want
- # to use it, here, so we need to call __enter__ manually, and have
- # __exit__ called after the transaction finishes.
- ctx = self._cache_id_gen.get_next()
- stream_id = ctx.__enter__()
- txn.call_on_exception(ctx.__exit__, None, None, None)
- txn.call_after(ctx.__exit__, None, None, None)
- txn.call_after(self.hs.get_notifier().on_new_replication_data)
-
- self._simple_insert_txn(
- txn,
- table="cache_invalidation_stream",
- values={
- "stream_id": stream_id,
- "cache_func": cache_name,
- "keys": list(keys),
- "invalidation_ts": self.clock.time_msec(),
- },
- )
-
- def get_all_updated_caches(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_caches_txn(txn):
- # We purposefully don't bound by the current token, as we want to
- # send across cache invalidations as quickly as possible. Cache
- # invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
- return txn.fetchall()
-
- return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
-
- def get_cache_stream_token(self):
- if self._cache_id_gen:
- return self._cache_id_gen.get_current_token()
- else:
- return 0
-
- def _simple_select_list_paginate(
- self,
- table,
- keyvalues,
- orderby,
- start,
- limit,
- retcols,
- order_direction="ASC",
- desc="_simple_select_list_paginate",
- ):
- """
- Executes a SELECT query on the named table with start and limit,
- of row numbers, which may return zero or number of rows from start to limit,
- returning the result as a list of dicts.
-
- Args:
- table (str): the table name
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self.runInteraction(
- desc,
- self._simple_select_list_paginate_txn,
- table,
- keyvalues,
- orderby,
- start,
- limit,
- retcols,
- order_direction=order_direction,
- )
-
- @classmethod
- def _simple_select_list_paginate_txn(
- cls,
- txn,
- table,
- keyvalues,
- orderby,
- start,
- limit,
- retcols,
- order_direction="ASC",
- ):
- """
- Executes a SELECT query on the named table with start and limit,
- of row numbers, which may return zero or number of rows from start to limit,
- returning the result as a list of dicts.
-
- Args:
- txn : Transaction object
- table (str): the table name
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- if order_direction not in ["ASC", "DESC"]:
- raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
-
- if keyvalues:
- where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
- else:
- where_clause = ""
-
- sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
- ", ".join(retcols),
- table,
- where_clause,
- orderby,
- order_direction,
- )
- txn.execute(sql, list(keyvalues.values()) + [limit, start])
-
- return cls.cursor_to_dict(txn)
-
- def get_user_count_txn(self, txn):
- """Get a total number of registered users in the users list.
-
- Args:
- txn : Transaction object
- Returns:
- int : number of users
- """
- sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
- txn.execute(sql_count)
- return txn.fetchone()[0]
-
- def _simple_search_list(
- self, table, term, col, retcols, desc="_simple_search_list"
- ):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
- """
-
- return self.runInteraction(
- desc, self._simple_search_list_txn, table, term, col, retcols
- )
-
- @classmethod
- def _simple_search_list_txn(cls, txn, table, term, col, retcols):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- txn : Transaction object
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
- """
- if term:
- sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
- termvalues = ["%%" + term + "%%"]
- txn.execute(sql, termvalues)
- else:
- return 0
-
- return cls.cursor_to_dict(txn)
-
- @property
- def database_engine_name(self):
- return self.database_engine.module.__name__
-
- def get_server_version(self):
- """Returns a string describing the server version number"""
- return self.database_engine.server_version
-
-
-class _RollbackButIsFineException(Exception):
- """ This exception is used to rollback a transaction without implying
- something went wrong.
- """
-
- pass
-
def db_to_json(db_content):
"""
@@ -1664,7 +108,7 @@ def db_to_json(db_content):
# Decode it to a Unicode string before feeding it to json.loads, so we
# consistently get a Unicode-containing object out.
if isinstance(db_content, (bytes, bytearray)):
- db_content = db_content.decode('utf8')
+ db_content = db_content.decode("utf8")
try:
return json.loads(db_content)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index b8b8273f73..eb1a7e5002 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Optional
from canonicaljson import json
@@ -22,7 +23,6 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines
-from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -74,7 +74,7 @@ class BackgroundUpdatePerformance(object):
return float(self.total_item_count) / float(self.total_duration_ms)
-class BackgroundUpdateStore(SQLBaseStore):
+class BackgroundUpdater(object):
""" Background updates are updates to the database that run in the
background. Each update processes a batch of data at once. We attempt to
limit the impact of each update by monitoring how long each batch takes to
@@ -86,24 +86,26 @@ class BackgroundUpdateStore(SQLBaseStore):
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
- def __init__(self, db_conn, hs):
- super(BackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, hs, database):
+ self._clock = hs.get_clock()
+ self.db = database
+
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
self._all_done = False
def start_doing_background_updates(self):
- run_as_background_process("background_updates", self._run_background_updates)
+ run_as_background_process("background_updates", self.run_background_updates)
- @defer.inlineCallbacks
- def _run_background_updates(self):
+ async def run_background_updates(self, sleep=True):
logger.info("Starting background schema updates")
while True:
- yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
+ if sleep:
+ await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
try:
- result = yield self.do_next_background_update(
+ result = await self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except Exception:
@@ -115,7 +117,7 @@ class BackgroundUpdateStore(SQLBaseStore):
" Unscheduling background update task."
)
self._all_done = True
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def has_completed_background_updates(self):
@@ -127,63 +129,85 @@ class BackgroundUpdateStore(SQLBaseStore):
# if we've previously determined that there is nothing left to do, that
# is easy
if self._all_done:
- defer.returnValue(True)
+ return True
# obviously, if we have things in our queue, we're not done.
if self._background_update_queue:
- defer.returnValue(False)
+ return False
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
- updates = yield self._simple_select_onecol(
+ updates = yield self.db.simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
- desc="check_background_updates",
+ desc="has_completed_background_updates",
)
if not updates:
self._all_done = True
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
- @defer.inlineCallbacks
- def do_next_background_update(self, desired_duration_ms):
+ async def has_completed_background_update(self, update_name) -> bool:
+ """Check if the given background update has finished running.
+ """
+
+ if self._all_done:
+ return True
+
+ if update_name in self._background_update_queue:
+ return False
+
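+ # The relevant row is deleted from `background_updates` when an update
+ # completes (see _end_background_update), so the update has finished
+ # iff no row remains for it.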
+ update_exists = await self.db.simple_select_one_onecol(
+ "background_updates",
+ keyvalues={"update_name": update_name},
+ retcol="1",
+ desc="has_completed_background_update",
+ allow_none=True,
+ )
+
+ return not update_exists
+
+ async def do_next_background_update(
+ self, desired_duration_ms: float
+ ) -> Optional[int]:
"""Does some amount of work on the next queued background update
+ Returns once some amount of work is done.
+
Args:
desired_duration_ms(float): How long we want to spend
updating.
Returns:
- A deferred that completes once some amount of work is done.
- The deferred will have a value of None if there is currently
- no more work to do.
+ None if there is no more work to do, otherwise an int
"""
if not self._background_update_queue:
- updates = yield self._simple_select_list(
+ updates = await self.db.simple_select_list(
"background_updates",
keyvalues=None,
retcols=("update_name", "depends_on"),
)
- in_flight = set(update["update_name"] for update in updates)
+ in_flight = {update["update_name"] for update in updates}
for update in updates:
if update["depends_on"] not in in_flight:
- self._background_update_queue.append(update['update_name'])
+ self._background_update_queue.append(update["update_name"])
if not self._background_update_queue:
# no work left to do
- defer.returnValue(None)
+ return None
# pop from the front, and add back to the back
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
- res = yield self._do_background_update(update_name, desired_duration_ms)
- defer.returnValue(res)
+ res = await self._do_background_update(update_name, desired_duration_ms)
+ return res
- @defer.inlineCallbacks
- def _do_background_update(self, update_name, desired_duration_ms):
+ async def _do_background_update(
+ self, update_name: str, desired_duration_ms: float
+ ) -> int:
logger.info("Starting update batch on background update '%s'", update_name)
update_handler = self._background_update_handlers[update_name]
@@ -203,7 +227,7 @@ class BackgroundUpdateStore(SQLBaseStore):
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
- progress_json = yield self._simple_select_one_onecol(
+ progress_json = await self.db.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json",
@@ -212,13 +236,13 @@ class BackgroundUpdateStore(SQLBaseStore):
progress = json.loads(progress_json)
time_start = self._clock.time_msec()
- items_updated = yield update_handler(progress, batch_size)
+ items_updated = await update_handler(progress, batch_size)
time_stop = self._clock.time_msec()
duration_ms = time_stop - time_start
logger.info(
- "Updating %r. Updated %r items in %rms."
+ "Running background update %r. Processed %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
update_name,
items_updated,
@@ -231,7 +255,7 @@ class BackgroundUpdateStore(SQLBaseStore):
performance.update(items_updated, duration_ms)
- defer.returnValue(len(self._background_update_performance))
+ return len(self._background_update_performance)
def register_background_update_handler(self, update_name, update_handler):
"""Register a handler for doing a background update.
@@ -241,7 +265,9 @@ class BackgroundUpdateStore(SQLBaseStore):
* A dict of the current progress
* An integer count of the number of items to update in this batch.
- The handler should return a deferred integer count of items updated.
+ The handler should return a deferred or coroutine which returns an integer count
+ of items updated.
+
The handler is responsible for updating the progress of the update.
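+ Example (a hypothetical handler sketch; names are illustrative)::
+
+ async def my_update_handler(progress, batch_size):
+ # process up to `batch_size` items, persist the new progress
+ # (e.g. via _background_update_progress), and return the count
+ return num_items_processed
+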
Args:
@@ -266,7 +292,7 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
- defer.returnValue(1)
+ return 1
self.register_background_update_handler(update_name, noop_update)
@@ -357,7 +383,7 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.debug("[SQL] %s", sql)
c.execute(sql)
- if isinstance(self.database_engine, engines.PostgresEngine):
+ if isinstance(self.db.engine, engines.PostgresEngine):
runner = create_index_psql
elif psql_only:
runner = None
@@ -368,9 +394,9 @@ class BackgroundUpdateStore(SQLBaseStore):
def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
- yield self.runWithConnection(runner)
+ yield self.db.runWithConnection(runner)
yield self._end_background_update(update_name)
- defer.returnValue(1)
+ return 1
self.register_background_update_handler(update_name, updater)
@@ -390,7 +416,7 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue = []
progress_json = json.dumps(progress)
- return self._simple_insert(
+ return self.db.simple_insert(
"background_updates",
{"update_name": update_name, "progress_json": progress_json},
)
@@ -406,10 +432,25 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue = [
name for name in self._background_update_queue if name != update_name
]
- return self._simple_delete_one(
+ return self.db.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
+ def _background_update_progress(self, update_name: str, progress: dict):
+ """Update the progress of a background update
+
+ Args:
+ update_name: The name of the background update task
+ progress: The progress of the update.
+ """
+
+ return self.db.runInteraction(
+ "background_update_progress",
+ self._background_update_progress_txn,
+ update_name,
+ progress,
+ )
+
def _background_update_progress_txn(self, txn, update_name, progress):
"""Update the progress of a background update
@@ -421,7 +462,7 @@ class BackgroundUpdateStore(SQLBaseStore):
progress_json = json.dumps(progress)
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
"background_updates",
keyvalues={"update_name": update_name},
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
new file mode 100644
index 0000000000..e1d03429ca
--- /dev/null
+++ b/synapse/storage/data_stores/__init__.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage.data_stores.state import StateGroupDataStore
+from synapse.storage.database import Database, make_conn
+from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
+
+logger = logging.getLogger(__name__)
+
+
+class DataStores(object):
+ """The various data stores.
+
+ These are low level interfaces to physical databases.
+
+ Attributes:
+ databases (list[Database])
+ main (DataStore)
+ state (StateGroupDataStore)
+ """
+
+ def __init__(self, main_store_class, hs):
+ # Note we pass in the main store class here as workers use a different main
+ # store.
+
+ self.databases = []
+ self.main = None
+ self.state = None
+
+ for database_config in hs.config.database.databases:
+ db_name = database_config.name
+ engine = create_engine(database_config.config)
+
+ with make_conn(database_config, engine) as db_conn:
+ logger.info("Preparing database %r...", db_name)
+
+ engine.check_database(db_conn)
+ prepare_database(
+ db_conn, engine, hs.config, data_stores=database_config.data_stores,
+ )
+
+ database = Database(hs, database_config, engine)
+
+ if "main" in database_config.data_stores:
+ logger.info("Starting 'main' data store")
+
+ # Sanity check we don't try and configure the main store on
+ # multiple databases.
+ if self.main:
+ raise Exception("'main' data store already configured")
+
+ self.main = main_store_class(database, db_conn, hs)
+
+ if "state" in database_config.data_stores:
+ logger.info("Starting 'state' data store")
+
+ # Sanity check we don't try and configure the state store on
+ # multiple databases.
+ if self.state:
+ raise Exception("'state' data store already configured")
+
+ self.state = StateGroupDataStore(database, db_conn, hs)
+
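+ # Commit the work done on this setup connection (schema preparation
+ # and any setup queries run by the store constructors).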
+ db_conn.commit()
+
+ self.databases.append(database)
+
+ logger.info("Database %r prepared", db_name)
+
+ # Sanity check that we have actually configured all the required stores.
+ if not self.main:
+ raise Exception("No 'main' data store configured")
+
+ if not self.state:
+ raise Exception("No 'main' data store configured")
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
new file mode 100644
index 0000000000..acca079f23
--- /dev/null
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -0,0 +1,583 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import calendar
+import logging
+import time
+
+from synapse.api.constants import PresenceState
+from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import (
+ ChainedIdGenerator,
+ IdGenerator,
+ StreamIdGenerator,
+)
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from .account_data import AccountDataStore
+from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
+from .cache import CacheInvalidationStore
+from .client_ips import ClientIpStore
+from .deviceinbox import DeviceInboxStore
+from .devices import DeviceStore
+from .directory import DirectoryStore
+from .e2e_room_keys import EndToEndRoomKeyStore
+from .end_to_end_keys import EndToEndKeyStore
+from .event_federation import EventFederationStore
+from .event_push_actions import EventPushActionsStore
+from .events import EventsStore
+from .events_bg_updates import EventsBackgroundUpdatesStore
+from .filtering import FilteringStore
+from .group_server import GroupServerStore
+from .keys import KeyStore
+from .media_repository import MediaRepositoryStore
+from .monthly_active_users import MonthlyActiveUsersStore
+from .openid import OpenIdStore
+from .presence import PresenceStore, UserPresenceState
+from .profile import ProfileStore
+from .push_rule import PushRuleStore
+from .pusher import PusherStore
+from .receipts import ReceiptsStore
+from .registration import RegistrationStore
+from .rejections import RejectionsStore
+from .relations import RelationsStore
+from .room import RoomStore
+from .roommember import RoomMemberStore
+from .search import SearchStore
+from .signatures import SignatureStore
+from .state import StateStore
+from .stats import StatsStore
+from .stream import StreamStore
+from .tags import TagsStore
+from .transactions import TransactionStore
+from .user_directory import UserDirectoryStore
+from .user_erasure_store import UserErasureStore
+
+logger = logging.getLogger(__name__)
+
+
+class DataStore(
+ EventsBackgroundUpdatesStore,
+ RoomMemberStore,
+ RoomStore,
+ RegistrationStore,
+ StreamStore,
+ ProfileStore,
+ PresenceStore,
+ TransactionStore,
+ DirectoryStore,
+ KeyStore,
+ StateStore,
+ SignatureStore,
+ ApplicationServiceStore,
+ EventsStore,
+ EventFederationStore,
+ MediaRepositoryStore,
+ RejectionsStore,
+ FilteringStore,
+ PusherStore,
+ PushRuleStore,
+ ApplicationServiceTransactionStore,
+ ReceiptsStore,
+ EndToEndKeyStore,
+ EndToEndRoomKeyStore,
+ SearchStore,
+ TagsStore,
+ AccountDataStore,
+ EventPushActionsStore,
+ OpenIdStore,
+ ClientIpStore,
+ DeviceStore,
+ DeviceInboxStore,
+ UserDirectoryStore,
+ GroupServerStore,
+ UserErasureStore,
+ MonthlyActiveUsersStore,
+ StatsStore,
+ RelationsStore,
+ CacheInvalidationStore,
+):
+ def __init__(self, database: Database, db_conn, hs):
+ self.hs = hs
+ self._clock = hs.get_clock()
+ self.database_engine = database.engine
+
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ extra_tables=[("local_invites", "stream_id")],
+ )
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ )
+ self._presence_id_gen = StreamIdGenerator(
+ db_conn, "presence_stream", "stream_id"
+ )
+ self._device_inbox_id_gen = StreamIdGenerator(
+ db_conn, "device_max_stream_id", "stream_id"
+ )
+ self._public_room_id_gen = StreamIdGenerator(
+ db_conn, "public_room_list_stream", "stream_id"
+ )
+ self._device_list_id_gen = StreamIdGenerator(
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[("user_signature_stream", "stream_id")],
+ )
+ self._cross_signing_id_gen = StreamIdGenerator(
+ db_conn, "e2e_cross_signing_keys", "stream_id"
+ )
+
+ self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+ self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
+ self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+ self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+ self._push_rules_stream_id_gen = ChainedIdGenerator(
+ self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+ )
+ self._pushers_id_gen = StreamIdGenerator(
+ db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+ )
+ self._group_updates_id_gen = StreamIdGenerator(
+ db_conn, "local_group_updates", "stream_id"
+ )
+
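+ # Only Postgres deployments stream cache invalidations to workers, so
+ # only they need an id generator here; on SQLite _cache_id_gen stays
+ # None and get_cache_stream_token() reports 0 (see cache.py below).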
+ if isinstance(self.database_engine, PostgresEngine):
+ self._cache_id_gen = StreamIdGenerator(
+ db_conn, "cache_invalidation_stream", "stream_id"
+ )
+ else:
+ self._cache_id_gen = None
+
+ super(DataStore, self).__init__(database, db_conn, hs)
+
+ self._presence_on_startup = self._get_active_presence(db_conn)
+
+ presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
+ db_conn,
+ "presence_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self._presence_id_gen.get_current_token(),
+ )
+ self.presence_stream_cache = StreamChangeCache(
+ "PresenceStreamChangeCache",
+ min_presence_val,
+ prefilled_cache=presence_cache_prefill,
+ )
+
+ max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+ device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
+ db_conn,
+ "device_inbox",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=max_device_inbox_id,
+ limit=1000,
+ )
+ self._device_inbox_stream_cache = StreamChangeCache(
+ "DeviceInboxStreamChangeCache",
+ min_device_inbox_id,
+ prefilled_cache=device_inbox_prefill,
+ )
+ # The federation outbox and the local device inbox use the same
+ # stream_id generator.
+ device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
+ db_conn,
+ "device_federation_outbox",
+ entity_column="destination",
+ stream_column="stream_id",
+ max_value=max_device_inbox_id,
+ limit=1000,
+ )
+ self._device_federation_outbox_stream_cache = StreamChangeCache(
+ "DeviceFederationOutboxStreamChangeCache",
+ min_device_outbox_id,
+ prefilled_cache=device_outbox_prefill,
+ )
+
+ device_list_max = self._device_list_id_gen.get_current_token()
+ self._device_list_stream_cache = StreamChangeCache(
+ "DeviceListStreamChangeCache", device_list_max
+ )
+ self._user_signature_stream_cache = StreamChangeCache(
+ "UserSignatureStreamChangeCache", device_list_max
+ )
+ self._device_list_federation_stream_cache = StreamChangeCache(
+ "DeviceListFederationStreamChangeCache", device_list_max
+ )
+
+ events_max = self._stream_id_gen.get_current_token()
+ curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+ db_conn,
+ "current_state_delta_stream",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=events_max, # As we share the stream id with events token
+ limit=1000,
+ )
+ self._curr_state_delta_stream_cache = StreamChangeCache(
+ "_curr_state_delta_stream_cache",
+ min_curr_state_delta_id,
+ prefilled_cache=curr_state_delta_prefill,
+ )
+
+ _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
+ db_conn,
+ "local_group_updates",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self._group_updates_id_gen.get_current_token(),
+ limit=1000,
+ )
+ self._group_updates_stream_cache = StreamChangeCache(
+ "_group_updates_stream_cache",
+ min_group_updates_id,
+ prefilled_cache=_group_updates_prefill,
+ )
+
+ self._stream_order_on_start = self.get_room_max_stream_ordering()
+ self._min_stream_order_on_start = self.get_room_min_stream_ordering()
+
+ # Used in _generate_user_daily_visits to keep track of progress
+ self._last_user_visit_update = self._get_start_of_day()
+
+ def take_presence_startup_info(self):
+ active_on_startup = self._presence_on_startup
+ self._presence_on_startup = None
+ return active_on_startup
+
+ def _get_active_presence(self, db_conn):
+ """Fetch non-offline presence from the database so that we can register
+ the appropriate timeouts.
+ """
+
+ sql = (
+ "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
+ " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
+ " WHERE state != ?"
+ )
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, (PresenceState.OFFLINE,))
+ rows = self.db.cursor_to_dict(txn)
+ txn.close()
+
+ for row in rows:
+ row["currently_active"] = bool(row["currently_active"])
+
+ return [UserPresenceState(**row) for row in rows]
+
+ def count_daily_users(self):
+ """
+ Counts the number of users who used this homeserver in the last 24 hours.
+ """
+ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+ return self.db.runInteraction("count_daily_users", self._count_users, yesterday)
+
+ def count_monthly_users(self):
+ """
+ Counts the number of users who used this homeserver in the last 30 days.
+ Note this method is intended for phonehome metrics only and is different
+ from the mau figure in synapse.storage.data_stores.main.monthly_active_users which,
+ amongst other things, includes a 3 day grace period before a user counts.
+ """
+ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+ return self.db.runInteraction(
+ "count_monthly_users", self._count_users, thirty_days_ago
+ )
+
+ def _count_users(self, txn, time_from):
+ """
+ Returns the number of distinct users seen since the given time_from timestamp (ms)
+ """
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT user_id FROM user_ips
+ WHERE last_seen > ?
+ GROUP BY user_id
+ ) u
+ """
+ txn.execute(sql, (time_from,))
+ (count,) = txn.fetchone()
+ return count
+
+ def count_r30_users(self):
+ """
+ Counts the number of 30 day retained users, defined as:
+ * Users who have created their accounts more than 30 days ago
+ * Who were last seen at most 30 days ago
+ * Where account creation and last_seen are > 30 days apart
+
+ Returns counts globally, as well as broken down by platform.
+ """
+
+ def _count_r30_users(txn):
+ thirty_days_in_secs = 86400 * 30
+ now = int(self._clock.time())
+ thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+ sql = """
+ SELECT platform, COALESCE(count(*), 0) FROM (
+ SELECT
+ users.name, platform, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen,
+ CASE
+ WHEN user_agent LIKE '%%Android%%' THEN 'android'
+ WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+ WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+ WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+ WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+ ELSE 'unknown'
+ END
+ AS platform
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND users.appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, platform, users.creation_ts
+ ) u GROUP BY platform
+ """
+
+ results = {}
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ for row in txn:
+ if row[0] == "unknown":
+ continue
+ results[row[0]] = row[1]
+
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT users.name, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, users.creation_ts
+ ) u
+ """
+
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ (count,) = txn.fetchone()
+ results["all"] = count
+
+ return results
+
+ return self.db.runInteraction("count_r30_users", _count_r30_users)
+
+ def _get_start_of_day(self):
+ """
+ Returns millisecond unixtime for start of UTC day.
+ """
+ now = time.gmtime()
+ today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
+ return today_start * 1000
+
+ def generate_user_daily_visits(self):
+ """
+ Generates daily visit data for use in cohort/ retention analysis
+ """
+
+ def _generate_user_daily_visits(txn):
+ logger.info("Calling _generate_user_daily_visits")
+ today_start = self._get_start_of_day()
+ a_day_in_milliseconds = 24 * 60 * 60 * 1000
+ now = self.clock.time_msec()
+
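+ # Record one user_daily_visits row per (user, device) seen in the
+ # given window that does not already have an entry for that day,
+ # skipping guests and appservice users.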
+ sql = """
+ INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+ SELECT u.user_id, u.device_id, ?
+ FROM user_ips AS u
+ LEFT JOIN (
+ SELECT user_id, device_id, timestamp FROM user_daily_visits
+ WHERE timestamp = ?
+ ) udv
+ ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+ INNER JOIN users ON users.name=u.user_id
+ WHERE last_seen > ? AND last_seen <= ?
+ AND udv.timestamp IS NULL AND users.is_guest=0
+ AND users.appservice_id IS NULL
+ GROUP BY u.user_id, u.device_id
+ """
+
+ # This means that the day has rolled over but there could still
+ # be entries from the previous day. There is an edge case
+ # where if the user logs in at 23:59 and overwrites their
+ # last_seen at 00:01 then they will not be counted in the
+ # previous day's stats - it is important that the query is run
+ # often to minimise this case.
+ if today_start > self._last_user_visit_update:
+ yesterday_start = today_start - a_day_in_milliseconds
+ txn.execute(
+ sql,
+ (
+ yesterday_start,
+ yesterday_start,
+ self._last_user_visit_update,
+ today_start,
+ ),
+ )
+ self._last_user_visit_update = today_start
+
+ txn.execute(
+ sql, (today_start, today_start, self._last_user_visit_update, now)
+ )
+ # Update _last_user_visit_update to now. The reason to do this
+ # rather than just clamping to the beginning of the day is to limit
+ # the size of the join - meaning that the query can be run more
+ # frequently.
+ self._last_user_visit_update = now
+
+ return self.db.runInteraction(
+ "generate_user_daily_visits", _generate_user_daily_visits
+ )
+
+ def get_users(self):
+ """Function to retrieve a list of users in users table.
+
+ Args:
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self.db.simple_select_list(
+ table="users",
+ keyvalues={},
+ retcols=[
+ "name",
+ "password_hash",
+ "is_guest",
+ "admin",
+ "user_type",
+ "deactivated",
+ ],
+ desc="get_users",
+ )
+
+ def get_users_paginate(
+ self, start, limit, name=None, guests=True, deactivated=False
+ ):
+ """Function to retrieve a paginated list of users from
+ users list. This will return a json list of users.
+
+ Args:
+ start (int): start number to begin the query from
+ limit (int): number of rows to retrieve
+ name (string): filter for user names
+ guests (bool): whether to include guest users
+ deactivated (bool): whether to include deactivated users
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ name_filter = {}
+ if name:
+ name_filter["name"] = "%" + name + "%"
+
+ attr_filter = {}
+ if not guests:
+ attr_filter["is_guest"] = 0
+ if not deactivated:
+ attr_filter["deactivated"] = 0
+
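+ # The "%"-wrapped name pattern goes in `filters` while the exact
+ # column values go in `keyvalues`; presumably `filters` entries are
+ # matched with LIKE rather than strict equality.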
+ return self.db.simple_select_list_paginate(
+ desc="get_users_paginate",
+ table="users",
+ orderby="name",
+ start=start,
+ limit=limit,
+ filters=name_filter,
+ keyvalues=attr_filter,
+ retcols=[
+ "name",
+ "password_hash",
+ "is_guest",
+ "admin",
+ "user_type",
+ "deactivated",
+ ],
+ )
+
+ def search_users(self, term):
+ """Function to search users list for one or more users with
+ the matched term.
+
+ Args:
+ term (str): search term
+ col (str): column to query term should be matched to
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self.db.simple_search_list(
+ table="users",
+ term=term,
+ col="name",
+ retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
+ desc="search_users",
+ )
+
+
+def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
+ """Called before upgrading an existing database to check that it is broadly sane
+ compared with the configuration.
+ """
+ domain = config.server_name
+
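+ # Count any users whose user ID does not end in ":<server_name>";
+ # such users cannot have been created on this homeserver, so finding
+ # any implies the server_name has changed since the database was made.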
+ sql = database_engine.convert_param_style(
+ "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+ )
+ pat = "%:" + domain
+ cur.execute(sql, (pat,))
+ num_not_matching = cur.fetchall()[0][0]
+ if num_not_matching == 0:
+ return
+
+ raise Exception(
+ "Found users in database not native to %s!\n"
+ "You cannot changed a synapse server_name after it's been configured"
+ % (domain,)
+ )
+
+
+__all__ = ["DataStore", "check_database_before_upgrade"]
diff --git a/synapse/storage/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 8394389073..46b494b334 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -22,6 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
)
- super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+ super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
@abc.abstractmethod
def get_max_account_data_stream_id(self):
@@ -67,7 +68,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_user_txn(txn):
- rows = self._simple_select_list_txn(
+ rows = self.db.simple_select_list_txn(
txn,
"account_data",
{"user_id": user_id},
@@ -78,7 +79,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
- rows = self._simple_select_list_txn(
+ rows = self.db.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id},
@@ -90,9 +91,9 @@ class AccountDataWorkerStore(SQLBaseStore):
room_data = by_room.setdefault(row["room_id"], {})
room_data[row["account_data_type"]] = json.loads(row["content"])
- return (global_account_data, by_room)
+ return global_account_data, by_room
- return self.runInteraction(
+ return self.db.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@@ -102,7 +103,7 @@ class AccountDataWorkerStore(SQLBaseStore):
Returns:
Deferred: A dict
"""
- result = yield self._simple_select_one_onecol(
+ result = yield self.db.simple_select_one_onecol(
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
@@ -111,9 +112,9 @@ class AccountDataWorkerStore(SQLBaseStore):
)
if result:
- defer.returnValue(json.loads(result))
+ return json.loads(result)
else:
- defer.returnValue(None)
+ return None
@cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
@@ -127,7 +128,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_txn(txn):
- rows = self._simple_select_list_txn(
+ rows = self.db.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
@@ -138,7 +139,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
- return self.runInteraction(
+ return self.db.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@@ -156,7 +157,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_and_type_txn(txn):
- content_json = self._simple_select_one_onecol_txn(
+ content_json = self.db.simple_select_one_onecol_txn(
txn,
table="room_account_data",
keyvalues={
@@ -170,7 +171,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return json.loads(content_json) if content_json else None
- return self.runInteraction(
+ return self.db.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@@ -184,14 +185,14 @@ class AccountDataWorkerStore(SQLBaseStore):
current_id(int): The position to fetch up to.
Returns:
A deferred pair of lists of tuples of stream_id int, user_id string,
- room_id string, type string, and content string.
+ room_id string, and type string.
"""
if last_room_id == current_id and last_global_id == current_id:
return defer.succeed(([], []))
def get_updated_account_data_txn(txn):
sql = (
- "SELECT stream_id, user_id, account_data_type, content"
+ "SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
@@ -199,15 +200,15 @@ class AccountDataWorkerStore(SQLBaseStore):
global_results = txn.fetchall()
sql = (
- "SELECT stream_id, user_id, room_id, account_data_type, content"
+ "SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_room_id, current_id, limit))
room_results = txn.fetchall()
- return (global_results, room_results)
+ return global_results, room_results
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_account_data_txn", get_updated_account_data_txn
)
@@ -244,15 +245,15 @@ class AccountDataWorkerStore(SQLBaseStore):
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2])
- return (global_account_data, account_data_by_room)
+ return global_account_data, account_data_by_room
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
- return ({}, {})
+ return defer.succeed(({}, {}))
- return self.runInteraction(
+ return self.db.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@@ -264,20 +265,18 @@ class AccountDataWorkerStore(SQLBaseStore):
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
- defer.returnValue(False)
+ return False
- defer.returnValue(
- ignored_user_id in ignored_account_data.get("ignored_users", {})
- )
+ return ignored_user_id in ignored_account_data.get("ignored_users", {})
class AccountDataStore(AccountDataWorkerStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
- super(AccountDataStore, self).__init__(db_conn, hs)
+ super(AccountDataStore, self).__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
@@ -302,9 +301,9 @@ class AccountDataStore(AccountDataWorkerStore):
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
- # on (user_id, room_id, account_data_type) so _simple_upsert will
+ # on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
- yield self._simple_upsert(
+ yield self.db.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
keyvalues={
@@ -332,7 +331,7 @@ class AccountDataStore(AccountDataWorkerStore):
)
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def add_account_data_for_user(self, user_id, account_data_type, content):
@@ -348,9 +347,9 @@ class AccountDataStore(AccountDataWorkerStore):
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
- # (user_id, account_data_type) so _simple_upsert will retry if
+ # (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
- yield self._simple_upsert(
+ yield self.db.simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@@ -373,7 +372,7 @@ class AccountDataStore(AccountDataWorkerStore):
)
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
def _update_max_stream_id(self, next_id):
"""Update the max stream_id
@@ -390,4 +389,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
- return self.runInteraction("update_account_data_max_stream_id", _update)
+ return self.db.runInteraction("update_account_data_max_stream_id", _update)
diff --git a/synapse/storage/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 9d9b28de13..9c52aa5340 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -22,9 +22,9 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
-from synapse.storage.events_worker import EventsWorkerStore
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
logger = logging.getLogger(__name__)
@@ -49,13 +49,13 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
- super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+ super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
def get_app_services(self):
return self.services_cache
@@ -134,8 +134,8 @@ class ApplicationServiceTransactionWorkerStore(
A Deferred which resolves to a list of ApplicationServices, which
may be empty.
"""
- results = yield self._simple_select_list(
- "application_services_state", dict(state=state), ["as_id"]
+ results = yield self.db.simple_select_list(
+ "application_services_state", {"state": state}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
@@ -145,7 +145,7 @@ class ApplicationServiceTransactionWorkerStore(
for service in as_list:
if service.id == res["as_id"]:
services.append(service)
- defer.returnValue(services)
+ return services
@defer.inlineCallbacks
def get_appservice_state(self, service):
@@ -156,17 +156,16 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves to ApplicationServiceState.
"""
- result = yield self._simple_select_one(
+ result = yield self.db.simple_select_one(
"application_services_state",
- dict(as_id=service.id),
+ {"as_id": service.id},
["state"],
allow_none=True,
desc="get_appservice_state",
)
if result:
- defer.returnValue(result.get("state"))
- return
- defer.returnValue(None)
+ return result.get("state")
+ return None
def set_appservice_state(self, service, state):
"""Set the application service state.
@@ -177,8 +176,8 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves when the state was set successfully.
"""
- return self._simple_upsert(
- "application_services_state", dict(as_id=service.id), dict(state=state)
+ return self.db.simple_upsert(
+ "application_services_state", {"as_id": service.id}, {"state": state}
)
def create_appservice_txn(self, service, events):
@@ -218,7 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
- return self.runInteraction("create_appservice_txn", _create_appservice_txn)
+ return self.db.runInteraction("create_appservice_txn", _create_appservice_txn)
def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction.
@@ -251,19 +250,23 @@ class ApplicationServiceTransactionWorkerStore(
)
# Set current txn_id for AS to 'txn_id'
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
"application_services_state",
- dict(as_id=service.id),
- dict(last_txn=txn_id),
+ {"as_id": service.id},
+ {"last_txn": txn_id},
)
# Delete txn
- self._simple_delete_txn(
- txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
+ self.db.simple_delete_txn(
+ txn,
+ "application_services_txns",
+ {"txn_id": txn_id, "as_id": service.id},
)
- return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
+ return self.db.runInteraction(
+ "complete_appservice_txn", _complete_appservice_txn
+ )
@defer.inlineCallbacks
def get_oldest_unsent_txn(self, service):
@@ -285,7 +288,7 @@ class ApplicationServiceTransactionWorkerStore(
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
return None
@@ -293,20 +296,18 @@ class ApplicationServiceTransactionWorkerStore(
return entry
- entry = yield self.runInteraction(
+ entry = yield self.db.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
)
if not entry:
- defer.returnValue(None)
+ return None
event_ids = json.loads(entry["event_ids"])
events = yield self.get_events_as_list(event_ids)
- defer.returnValue(
- AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
- )
+ return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
def _get_last_txn(self, txn, service_id):
txn.execute(
@@ -325,7 +326,7 @@ class ApplicationServiceTransactionWorkerStore(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
- return self.runInteraction(
+ return self.db.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
@@ -354,13 +355,13 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.runInteraction(
+ upper_bound, event_ids = yield self.db.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
events = yield self.get_events_as_list(event_ids)
- defer.returnValue((upper_bound, events))
+ return upper_bound, events
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
new file mode 100644
index 0000000000..d4c44dcc75
--- /dev/null
+++ b/synapse/storage/data_stores/main/cache.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import itertools
+import logging
+from typing import Any, Iterable, Optional, Tuple
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.engines import PostgresEngine
+from synapse.util.iterutils import batch_iter
+
+logger = logging.getLogger(__name__)
+
+
+# This is a special cache name we use to batch multiple invalidations of caches
+# based on the current state when notifying workers over replication.
+CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
+
+
+class CacheInvalidationStore(SQLBaseStore):
+ async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+ """Invalidates the cache and adds it to the cache stream so slaves
+ will know to invalidate their caches.
+
+ This should only be used to invalidate caches where slaves won't
+ otherwise know from other replication streams that the cache should
+ be invalidated.
+ """
+ cache_func = getattr(self, cache_name, None)
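+ # This store/worker may not define the named cache at all, in which
+ # case there is nothing to do.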
+ if not cache_func:
+ return
+
+ cache_func.invalidate(keys)
+ await self.runInteraction(
+ "invalidate_cache_and_stream",
+ self._send_invalidation_to_replication,
+ cache_func.__name__,
+ keys,
+ )
+
+ def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+ """Invalidates the cache and adds it to the cache stream so slaves
+ will know to invalidate their caches.
+
+ This should only be used to invalidate caches where slaves won't
+ otherwise know from other replication streams that the cache should
+ be invalidated.
+ """
+ txn.call_after(cache_func.invalidate, keys)
+ self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
+
+ def _invalidate_all_cache_and_stream(self, txn, cache_func):
+ """Invalidates the entire cache and adds it to the cache stream so slaves
+ will know to invalidate their caches.
+ """
+
+ txn.call_after(cache_func.invalidate_all)
+ self._send_invalidation_to_replication(txn, cache_func.__name__, None)
+
+ def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+ """Special case invalidation of caches based on current state.
+
+ We special case this so that we can batch the cache invalidations into a
+ single replication poke.
+
+ Args:
+ txn
+ room_id (str): Room where state changed
+ members_changed (iterable[str]): The user_ids of members that have changed
+ """
+ txn.call_after(self._invalidate_state_caches, room_id, members_changed)
+
+ if members_changed:
+ # We need to be careful that the size of the `members_changed` list
+ # isn't so large that it causes problems sending over replication, so we
+ # send them in chunks.
+ # Max line length is 16K, and max user ID length is 255, so 50 should
+ # be safe.
+ for chunk in batch_iter(members_changed, 50):
+ keys = itertools.chain([room_id], chunk)
+ self._send_invalidation_to_replication(
+ txn, CURRENT_STATE_CACHE_NAME, keys
+ )
+ else:
+ # if no members changed, we still need to invalidate the other caches.
+ self._send_invalidation_to_replication(
+ txn, CURRENT_STATE_CACHE_NAME, [room_id]
+ )
+
+ def _send_invalidation_to_replication(
+ self, txn, cache_name: str, keys: Optional[Iterable[Any]]
+ ):
+ """Notifies replication that given cache has been invalidated.
+
+ Note that this does *not* invalidate the cache locally.
+
+ Args:
+ txn
+ cache_name
+ keys: Entry to invalidate. If None will invalidate all.
+ """
+
+ if cache_name == CURRENT_STATE_CACHE_NAME and keys is None:
+ raise Exception(
+ "Can't stream invalidate all with magic current state cache"
+ )
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # get_next() returns a context manager which is designed to wrap
+ # the transaction. However, we want to only get an ID when we want
+ # to use it, here, so we need to call __enter__ manually, and have
+ # __exit__ called after the transaction finishes.
+ ctx = self._cache_id_gen.get_next()
+ stream_id = ctx.__enter__()
+ txn.call_on_exception(ctx.__exit__, None, None, None)
+ txn.call_after(ctx.__exit__, None, None, None)
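+ # (Only one of these two callbacks fires: call_on_exception on rollback,
+ # call_after on commit, so the context is always exited exactly once.)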
+ txn.call_after(self.hs.get_notifier().on_new_replication_data)
+
+ if keys is not None:
+ keys = list(keys)
+
+ self.db.simple_insert_txn(
+ txn,
+ table="cache_invalidation_stream",
+ values={
+ "stream_id": stream_id,
+ "cache_func": cache_name,
+ "keys": keys,
+ "invalidation_ts": self.clock.time_msec(),
+ },
+ )
+
+ def get_all_updated_caches(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_caches_txn(txn):
+ # We purposefully don't bound by the current token, as we want to
+ # send across cache invalidations as quickly as possible. Cache
+ # invalidations are idempotent, so duplicates are fine.
+ sql = (
+ "SELECT stream_id, cache_func, keys, invalidation_ts"
+ " FROM cache_invalidation_stream"
+ " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_updated_caches", get_all_updated_caches_txn
+ )
+
+ def get_cache_stream_token(self):
+ if self._cache_id_gen:
+ return self._cache_id_gen.get_current_token()
+ else:
+ return 0
diff --git a/synapse/storage/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index bda68de5be..e1ccb27142 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -19,11 +19,11 @@ from six import iteritems
from twisted.internet import defer
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.util.caches import CACHE_SIZE_FACTOR
-
-from . import background_updates
-from ._base import Cache
+from synapse.util.caches.descriptors import Cache
logger = logging.getLogger(__name__)
@@ -33,46 +33,41 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
-class ClientIpStore(background_updates.BackgroundUpdateStore):
- def __init__(self, db_conn, hs):
-
- self.client_ip_last_seen = Cache(
- name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
- )
+class ClientIpBackgroundUpdateStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- super(ClientIpStore, self).__init__(db_conn, hs)
-
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"user_ips_device_index",
index_name="user_ips_device_id",
table="user_ips",
columns=["user_id", "device_id", "last_seen"],
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"user_ips_last_seen_index",
index_name="user_ips_last_seen",
table="user_ips",
columns=["user_id", "last_seen"],
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"user_ips_last_seen_only_index",
index_name="user_ips_last_seen_only",
table="user_ips",
columns=["last_seen"],
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"user_ips_analyze", self._analyze_user_ip
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"user_ips_remove_dupes", self._remove_user_ip_dupes
)
# Register a unique index
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"user_ips_device_unique_index",
index_name="user_ips_user_token_ip_unique_index",
table="user_ips",
@@ -81,18 +76,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
# Drop the old non-unique index
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
)
- # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
- self._batch_row_update = {}
-
- self._client_ip_looper = self._clock.looping_call(
- self._update_client_ips_batch, 5 * 1000
- )
- self.hs.get_reactor().addSystemEventTrigger(
- "before", "shutdown", self._update_client_ips_batch
+ # Update the last seen info in devices.
+ self.db.updates.register_background_update_handler(
+ "devices_last_seen", self._devices_last_seen_update
)
@defer.inlineCallbacks
@@ -102,9 +92,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
- yield self.runWithConnection(f)
- yield self._end_background_update("user_ips_drop_nonunique_index")
- defer.returnValue(1)
+ yield self.db.runWithConnection(f)
+ yield self.db.updates._end_background_update("user_ips_drop_nonunique_index")
+ return 1
@defer.inlineCallbacks
def _analyze_user_ip(self, progress, batch_size):
@@ -117,11 +107,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")
- yield self.runInteraction("user_ips_analyze", user_ips_analyze)
+ yield self.db.runInteraction("user_ips_analyze", user_ips_analyze)
- yield self._end_background_update("user_ips_analyze")
+ yield self.db.updates._end_background_update("user_ips_analyze")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _remove_user_ip_dupes(self, progress, batch_size):
@@ -151,7 +141,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
return None
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
- end_last_seen = yield self.runInteraction(
+ end_last_seen = yield self.db.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)
@@ -282,16 +272,120 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
(user_id, access_token, ip, device_id, user_agent, last_seen),
)
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)
- yield self.runInteraction("user_ips_dups_remove", remove)
+ yield self.db.runInteraction("user_ips_dups_remove", remove)
if last:
- yield self._end_background_update("user_ips_remove_dupes")
+ yield self.db.updates._end_background_update("user_ips_remove_dupes")
+
+ return batch_size
+
+ @defer.inlineCallbacks
+ def _devices_last_seen_update(self, progress, batch_size):
+ """Background update to insert last seen info into the devices table
+ """
+
+ last_user_id = progress.get("last_user_id", "")
+ last_device_id = progress.get("last_device_id", "")
+
+ def _devices_last_seen_update_txn(txn):
+ # This consists of two queries:
+ #
+ # 1. The sub-query searches for the next N devices and joins
+ # against user_ips to find the max last_seen associated with
+ # that device.
+ # 2. The outer query then joins again against user_ips on
+ # user/device/last_seen. This *should* hopefully only
+ # return one row, but if it does return more than one then
+ # we'll just end up updating the same device row multiple
+ # times, which is fine.
+
+ if self.database_engine.supports_tuple_comparison:
+ where_clause = "(user_id, device_id) > (?, ?)"
+ where_args = [last_user_id, last_device_id]
+ else:
+ # We explicitly do a `user_id >= ? AND (...)` here to ensure
+ # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
+ # makes it hard for the query optimiser to tell that it can use the
+ # index on user_id
+ where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
+ where_args = [last_user_id, last_user_id, last_device_id]
+
+ sql = """
+ SELECT
+ last_seen, ip, user_agent, user_id, device_id
+ FROM (
+ SELECT
+ user_id, device_id, MAX(u.last_seen) AS last_seen
+ FROM devices
+ INNER JOIN user_ips AS u USING (user_id, device_id)
+ WHERE %(where_clause)s
+ GROUP BY user_id, device_id
+ ORDER BY user_id ASC, device_id ASC
+ LIMIT ?
+ ) c
+ INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
+ """ % {
+ "where_clause": where_clause
+ }
+ txn.execute(sql, where_args + [batch_size])
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ sql = """
+ UPDATE devices
+ SET last_seen = ?, ip = ?, user_agent = ?
+ WHERE user_id = ? AND device_id = ?
+ """
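+ # The SELECT above returns (last_seen, ip, user_agent, user_id, device_id),
+ # which matches the placeholder order here, so each row can be passed
+ # straight to execute_batch.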
+ txn.execute_batch(sql, rows)
+
+ _, _, _, user_id, device_id = rows[-1]
+ self.db.updates._background_update_progress_txn(
+ txn,
+ "devices_last_seen",
+ {"last_user_id": user_id, "last_device_id": device_id},
+ )
+
+ return len(rows)
+
+ updated = yield self.db.runInteraction(
+ "_devices_last_seen_update", _devices_last_seen_update_txn
+ )
+
+ if not updated:
+ yield self.db.updates._end_background_update("devices_last_seen")
+
+ return updated
+
+
+class ClientIpStore(ClientIpBackgroundUpdateStore):
+ def __init__(self, database: Database, db_conn, hs):
+
+ self.client_ip_last_seen = Cache(
+ name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+ )
+
+ super(ClientIpStore, self).__init__(database, db_conn, hs)
+
+ self.user_ips_max_age = hs.config.user_ips_max_age
- defer.returnValue(batch_size)
+ # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+ self._batch_row_update = {}
+
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+ self.hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", self._update_client_ips_batch
+ )
+
+ if self.user_ips_max_age:
+ self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks
def insert_client_ip(
@@ -314,23 +408,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
+ @wrap_as_background_process("update_client_ips")
def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating
- if not self.hs.get_db_pool().running:
+ if not self.db.is_running():
return
- def update():
- to_update = self._batch_row_update
- self._batch_row_update = {}
- return self.runInteraction(
- "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
- )
+ to_update = self._batch_row_update
+ self._batch_row_update = {}
- return run_as_background_process("update_client_ips", update)
+ return self.db.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+ )
def _update_client_ips_batch_txn(self, txn, to_update):
- if "user_ips" in self._unsafe_to_upsert_tables or (
+ if "user_ips" in self.db._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
self.database_engine.lock_table(txn, "user_ips")
@@ -339,7 +432,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try:
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="user_ips",
keyvalues={
@@ -354,6 +447,23 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
},
lock=False,
)
+
+ # Technically an access token might not be associated with
+ # a device so we need to check.
+ if device_id:
+ # this is always an update rather than an upsert: the row should
+ # already exist, and if it doesn't, that may be because it has been
+ # deleted, and we don't want to re-create it.
+ self.db.simple_update_txn(
+ txn,
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ updatevalues={
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ "ip": ip,
+ },
+ )
except Exception as e:
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)
@@ -372,19 +482,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
keys giving the column names
"""
- res = yield self.runInteraction(
- "get_last_client_ip_by_device",
- self._get_last_client_ip_by_device_txn,
- user_id,
- device_id,
- retcols=(
- "user_id",
- "access_token",
- "ip",
- "user_agent",
- "device_id",
- "last_seen",
- ),
+ keyvalues = {"user_id": user_id}
+ if device_id is not None:
+ keyvalues["device_id"] = device_id
+
+ res = yield self.db.simple_select_list(
+ table="devices",
+ keyvalues=keyvalues,
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
@@ -401,43 +506,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id": did,
"last_seen": last_seen,
}
- defer.returnValue(ret)
-
- @classmethod
- def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
- where_clauses = []
- bindings = []
- if device_id is None:
- where_clauses.append("user_id = ?")
- bindings.extend((user_id,))
- else:
- where_clauses.append("(user_id = ? AND device_id = ?)")
- bindings.extend((user_id, device_id))
-
- if not where_clauses:
- return []
-
- inner_select = (
- "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
- "WHERE %(where)s "
- "GROUP BY user_id, device_id"
- ) % {"where": " OR ".join(where_clauses)}
-
- sql = (
- "SELECT %(retcols)s FROM user_ips "
- "JOIN (%(inner_select)s) ips ON"
- " user_ips.last_seen = ips.mls AND"
- " user_ips.user_id = ips.user_id AND"
- " (user_ips.device_id = ips.device_id OR"
- " (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
- " )"
- ) % {
- "retcols": ",".join("user_ips." + c for c in retcols),
- "inner_select": inner_select,
- }
-
- txn.execute(sql, bindings)
- return cls.cursor_to_dict(txn)
+ return ret
@defer.inlineCallbacks
def get_user_ip_and_agents(self, user):
@@ -450,7 +519,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
- rows = yield self._simple_select_list(
+ rows = yield self.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"],
@@ -461,14 +530,56 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
for row in rows
)
- defer.returnValue(
- list(
- {
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "last_seen": last_seen,
- }
- for (access_token, ip), (user_agent, last_seen) in iteritems(results)
+ return [
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for (access_token, ip), (user_agent, last_seen) in iteritems(results)
+ ]
+
+ @wrap_as_background_process("prune_old_user_ips")
+ async def _prune_old_user_ips(self):
+ """Removes entries from the user_ips table that are older than the configured period.
+ """
+
+ if self.user_ips_max_age is None:
+ # Nothing to do
+ return
+
+ if not await self.db.updates.has_completed_background_update(
+ "devices_last_seen"
+ ):
+ # Only start pruning if we have finished populating the devices
+ # last seen info.
+ return
+
+ # We do a slightly funky SQL delete to ensure we don't try to delete
+ # too much at once (as the table may be very large from before we
+ # started pruning).
+ #
+ # This works by finding the max last_seen that is less than the given
+ # time, but has no more than N rows before it, deleting all rows with
+ # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+ # returns exactly one row).
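+ # (COALESCE falls back to -1 when there is nothing old enough, in which
+ # case the DELETE matches no rows.)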
+ sql = """
+ DELETE FROM user_ips
+ WHERE last_seen <= (
+ SELECT COALESCE(MAX(last_seen), -1)
+ FROM (
+ SELECT last_seen FROM user_ips
+ WHERE last_seen <= ?
+ ORDER BY last_seen ASC
+ LIMIT 5000
+ ) AS u
)
- )
+ """
+
+ timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+ def _prune_old_user_ips_txn(txn):
+ txn.execute(sql, (timestamp,))
+
+ await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 4ea0deea4f..0613b49f4a 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -19,8 +19,9 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
logger = logging.getLogger(__name__)
@@ -66,12 +67,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
- return (messages, stream_pos)
+ return messages, stream_pos
- return self.runInteraction(
+ return self.db.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn
)
+ @trace
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
@@ -87,12 +89,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), None
)
+
+ set_tag("last_deleted_stream_id", last_deleted_stream_id)
+
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_deleted_stream_id
)
if not has_changed:
- defer.returnValue(0)
+ log_kv({"message": "No changes in cache since last check"})
+ return 0
def delete_messages_for_device_txn(txn):
sql = (
@@ -103,10 +109,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
- count = yield self.runInteraction(
+ count = yield self.db.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
+ log_kv(
+ {"message": "deleted {} messages for device".format(count), "count": count}
+ )
+
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
@@ -115,8 +125,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
last_deleted_stream_id, up_to_stream_id
)
- defer.returnValue(count)
+ return count
+ @trace
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit
):
@@ -132,16 +143,23 @@ class DeviceInboxWorkerStore(SQLBaseStore):
in the stream the messages got to.
"""
+ set_tag("destination", destination)
+ set_tag("last_stream_id", last_stream_id)
+ set_tag("current_stream_id", current_stream_id)
+ set_tag("limit", limit)
+
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
destination, last_stream_id
)
if not has_changed or last_stream_id == current_stream_id:
+ log_kv({"message": "No new messages in stream"})
return defer.succeed(([], current_stream_id))
if limit <= 0:
# This can happen if we run out of room for EDUs in the transaction.
return defer.succeed(([], last_stream_id))
+ @trace
def get_new_messages_for_remote_destination_txn(txn):
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
@@ -156,14 +174,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
+ log_kv({"message": "Set stream position to current position"})
stream_pos = current_stream_id
- return (messages, stream_pos)
+ return messages, stream_pos
- return self.runInteraction(
+ return self.db.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
+ @trace
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
"""Used to delete messages when the remote destination acknowledges
their receipt.
@@ -183,28 +203,48 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
txn.execute(sql, (destination, up_to_stream_id))
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
-class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
+class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- def __init__(self, db_conn, hs):
- super(DeviceInboxStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"device_inbox_stream_index",
index_name="device_inbox_stream_id_user_id",
table="device_inbox",
columns=["stream_id", "user_id"],
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
+ @defer.inlineCallbacks
+ def _background_drop_index_device_inbox(self, progress, batch_size):
+ def reindex_txn(conn):
+ txn = conn.cursor()
+ txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
+ txn.close()
+
+ yield self.db.runWithConnection(reindex_txn)
+
+ yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+ return 1
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
+ DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceInboxStore, self).__init__(database, db_conn, hs)
+
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
@@ -214,6 +254,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
expiry_ms=30 * 60 * 1000,
)
+ @trace
@defer.inlineCallbacks
def add_messages_to_device_inbox(
self, local_messages_by_user_then_device, remote_messages_by_destination
@@ -253,7 +294,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
for user_id in local_messages_by_user_then_device.keys():
@@ -263,7 +304,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
destination, stream_id
)
- defer.returnValue(self._device_inbox_id_gen.get_current_token())
+ return self._device_inbox_id_gen.get_current_token()
@defer.inlineCallbacks
def add_messages_from_remote_to_device_inbox(
@@ -273,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
- already_inserted = self._simple_select_one_txn(
+ already_inserted = self.db.simple_select_one_txn(
txn,
table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
@@ -285,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# Add an entry for this message_id so that we know we've processed
# it.
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="device_federation_inbox",
values={
@@ -303,7 +344,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
now_ms,
@@ -312,7 +353,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
- defer.returnValue(stream_id)
+ return stream_id
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
@@ -326,7 +367,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
- sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
+ sql = "SELECT device_id FROM devices WHERE user_id = ?"
txn.execute(sql, (user_id,))
message_json = json.dumps(messages_by_device["*"])
for row in txn:
@@ -337,15 +378,15 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
else:
if not devices:
continue
- sql = (
- "SELECT device_id FROM devices"
- " WHERE user_id = ? AND device_id IN ("
- + ",".join("?" * len(devices))
- + ")"
+
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "device_id", devices
)
+ sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
+
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
- txn.execute(sql, [user_id] + devices)
+ txn.execute(sql, [user_id] + list(args))
for row in txn:
# Only insert into the local inbox if the device exists on
# this server
@@ -411,19 +452,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
return rows
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
-
- @defer.inlineCallbacks
- def _background_drop_index_device_inbox(self, progress, batch_size):
- def reindex_txn(conn):
- txn = conn.cursor()
- txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
- txn.close()
-
- yield self.runWithConnection(reindex_txn)
-
- yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
-
- defer.returnValue(1)
diff --git a/synapse/storage/devices.py b/synapse/storage/data_stores/main/devices.py
index d102e07372..8af5f7de54 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,11 +22,24 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.api.errors import StoreError
+from synapse.api.errors import Codes, StoreError
+from synapse.logging.opentracing import (
+ get_active_span_text_map,
+ set_tag,
+ trace,
+ whitelisted_homeserver,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import Cache, SQLBaseStore, db_to_json
-from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import Database
+from synapse.types import Collection, get_verify_key_from_cross_signing_key
+from synapse.util.caches.descriptors import (
+ Cache,
+ cached,
+ cachedInlineCallbacks,
+ cachedList,
+)
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
@@ -35,7 +50,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
- """Retrieve a device.
+ """Retrieve a device. Only returns devices that are not marked as
+ hidden.
Args:
user_id (str): The ID of the user which owns the device
@@ -45,16 +61,17 @@ class DeviceWorkerStore(SQLBaseStore):
Raises:
StoreError: if the device is not found
"""
- return self._simple_select_one(
+ return self.db.simple_select_one(
table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
- """Retrieve all of a user's registered devices.
+ """Retrieve all of a user's registered devices. Only returns devices
+ that are not marked as hidden.
Args:
user_id (str):
@@ -63,23 +80,29 @@ class DeviceWorkerStore(SQLBaseStore):
containing "device_id", "user_id" and "display_name" for each
device.
"""
- devices = yield self._simple_select_list(
+ devices = yield self.db.simple_select_list(
table="devices",
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user",
)
- defer.returnValue({d["device_id"]: d for d in devices})
+ return {d["device_id"]: d for d in devices}
+ @trace
@defer.inlineCallbacks
- def get_devices_by_remote(self, destination, from_stream_id, limit):
- """Get stream of updates to send to remote servers
+ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+ """Get a stream of device updates to send to the given remote server.
+ Args:
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ limit (int): Maximum number of device updates to return
Returns:
- Deferred[tuple[int, list[dict]]]:
+ Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the
- response), and the list of updates
+ response), and the list of updates, where each update is a pair of EDU
+ type and EDU contents
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -87,7 +110,7 @@ class DeviceWorkerStore(SQLBaseStore):
destination, int(from_stream_id)
)
if not has_changed:
- defer.returnValue((now_stream_id, []))
+ return now_stream_id, []
# We retrieve n+1 devices from the list of outbound pokes where n is
# our outbound device update limit. We then check if the very last
@@ -99,9 +122,9 @@ class DeviceWorkerStore(SQLBaseStore):
# consider the device update to be too large, and simply skip the
# stream_id; the rationale being that such a large device list update
# is likely an error.
- updates = yield self.runInteraction(
- "get_devices_by_remote",
- self._get_devices_by_remote_txn,
+ updates = yield self.db.runInteraction(
+ "get_device_updates_by_remote",
+ self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
@@ -110,7 +133,38 @@ class DeviceWorkerStore(SQLBaseStore):
# Return an empty list if there are no updates
if not updates:
- defer.returnValue((now_stream_id, []))
+ return now_stream_id, []
+
+ # get the cross-signing keys of the users in the list, so that we can
+ # determine which of the device changes were cross-signing keys
+ users = {r[0] for r in updates}
+ master_key_by_user = {}
+ self_signing_key_by_user = {}
+ for user in users:
+ cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ master_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
+ cross_signing_key = yield self.get_e2e_cross_signing_key(
+ user, "self_signing"
+ )
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ self_signing_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
@@ -126,16 +180,43 @@ class DeviceWorkerStore(SQLBaseStore):
# (user_id, device_id) entries into a map, with the value being
# the max stream_id across each set of duplicate entries
#
- # maps (user_id, device_id) -> stream_id
+ # maps (user_id, device_id) -> (stream_id, opentracing_context)
# as long as their stream_id does not match that of the last row
+ #
+ # opentracing_context contains the opentracing metadata for the request
+ # that created the poke
+ #
+ # The most recent request's opentracing_context is used as the
+ # context for the resulting EDU.
+
query_map = {}
- for update in updates:
- if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+ cross_signing_keys_by_user = {}
+ for user_id, device_id, update_stream_id, update_context in updates:
+ if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break
- key = (update[0], update[1])
- query_map[key] = max(query_map.get(key, 0), update[2])
+ if (
+ user_id in master_key_by_user
+ and device_id == master_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["master_key"] = master_key_by_user[user_id]["key_info"]
+ elif (
+ user_id in self_signing_key_by_user
+ and device_id == self_signing_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["self_signing_key"] = self_signing_key_by_user[user_id][
+ "key_info"
+ ]
+ else:
+ key = (user_id, device_id)
+
+ previous_update_stream_id, _ = query_map.get(key, (0, None))
+
+ if update_stream_id > previous_update_stream_id:
+ query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
@@ -145,18 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
- if not query_map:
- defer.returnValue((stream_id_cutoff, []))
+ if not query_map and not cross_signing_keys_by_user:
+ return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote(
- destination,
- from_stream_id,
- query_map,
+ destination, from_stream_id, query_map
)
- defer.returnValue((now_stream_id, results))
+ # add the updated cross-signing keys to the results list
+ for user_id, result in iteritems(cross_signing_keys_by_user):
+ result["user_id"] = user_id
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("org.matrix.signing_key_update", result))
- def _get_devices_by_remote_txn(
+ return now_stream_id, results
+
+ def _get_device_updates_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination
@@ -171,8 +256,9 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
List: List of device updates
"""
+ # get the list of device updates that need to be sent
sql = """
- SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
+ SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
ORDER BY stream_id
LIMIT ?
@@ -182,27 +268,30 @@ class DeviceWorkerStore(SQLBaseStore):
return list(txn)
@defer.inlineCallbacks
- def _get_device_update_edus_by_remote(
- self, destination, from_stream_id, query_map,
- ):
+ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
"""Returns a list of device update EDUs as well as E2EE keys
Args:
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
- query_map (Dict[(str, str): int]): Dictionary mapping
- user_id/device_id to update stream_id
+ query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
+ user_id/device_id to update stream_id and the relevant json-encoded
+ opentracing context
Returns:
List[Dict]: List of objects representing a device update EDU
"""
- devices = yield self.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
- query_map.keys(),
- include_all_devices=True,
- include_deleted_devices=True,
+ devices = (
+ yield self.db.runInteraction(
+ "_get_e2e_device_keys_txn",
+ self._get_e2e_device_keys_txn,
+ query_map.keys(),
+ include_all_devices=True,
+ include_deleted_devices=True,
+ )
+ if query_map
+ else {}
)
results = []
@@ -210,15 +299,16 @@ class DeviceWorkerStore(SQLBaseStore):
# The prev_id for the first row is always the last row before
# `from_stream_id`
prev_id = yield self._get_last_device_update_for_remote_user(
- destination, user_id, from_stream_id,
+ destination, user_id, from_stream_id
)
for device_id, device in iteritems(user_devices):
- stream_id = query_map[(user_id, device_id)]
+ stream_id, opentracing_context = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
+ "org.matrix.opentracing_context": opentracing_context,
}
prev_id = stream_id
@@ -227,18 +317,25 @@ class DeviceWorkerStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
+
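+ # Any signatures stored for this device (e.g. cross-signing signatures)
+ # are merged into the key JSON before the update is sent out.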
+ if "signatures" in device:
+ for sig_user_id, sigs in device["signatures"].items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
else:
result["deleted"] = True
- results.append(result)
+ results.append(("m.device_list_update", result))
- defer.returnValue(results)
+ return results
def _get_last_device_update_for_remote_user(
- self, destination, user_id, from_stream_id,
+ self, destination, user_id, from_stream_id
):
def f(txn):
prev_sent_id_sql = """
@@ -250,12 +347,12 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
- return self.runInteraction("get_last_device_update_for_remote_user", f)
+ return self.db.runInteraction("get_last_device_update_for_remote_user", f)
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
@@ -299,9 +396,45 @@ class DeviceWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (destination, stream_id))
+ @defer.inlineCallbacks
+ def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+ """Persist that a user has made new signatures
+
+ Args:
+ from_user_id (str): the user who made the signatures
+ user_ids (list[str]): the users who were signed
+ """
+
+ with self._device_list_id_gen.get_next() as stream_id:
+ yield self.db.runInteraction(
+ "add_user_sig_change_to_streams",
+ self._add_user_signature_change_txn,
+ from_user_id,
+ user_ids,
+ stream_id,
+ )
+ return stream_id
+
+ def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+ txn.call_after(
+ self._user_signature_stream_cache.entity_has_changed,
+ from_user_id,
+ stream_id,
+ )
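+ # The full list of signed users is stored JSON-encoded in a single column.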
+ self.db.simple_insert_txn(
+ txn,
+ "user_signature_stream",
+ values={
+ "stream_id": stream_id,
+ "from_user_id": from_user_id,
+ "user_ids": json.dumps(user_ids),
+ },
+ )
+
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
+ @trace
@defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
@@ -315,11 +448,17 @@ class DeviceWorkerStore(SQLBaseStore):
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
- user_ids = set(user_id for user_id, _ in query_list)
+ user_ids = {user_id for user_id, _ in query_list}
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
- user_ids_in_cache = set(
- user_id for user_id, stream_id in user_map.items() if stream_id
+
+ # We go and check if any of the users need to have their device lists
+ # resynced. If they do then we remove them from the cached list.
+ users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
+ user_ids
)
+ user_ids_in_cache = {
+ user_id for user_id, stream_id in user_map.items() if stream_id
+ } - users_needing_resync
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
@@ -331,31 +470,34 @@ class DeviceWorkerStore(SQLBaseStore):
device = yield self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
- results[user_id] = yield self._get_cached_devices_for_user(user_id)
+ results[user_id] = yield self.get_cached_devices_for_user(user_id)
- defer.returnValue((user_ids_not_in_cache, results))
+ set_tag("in_cache", results)
+ set_tag("not_in_cache", user_ids_not_in_cache)
+
+ return user_ids_not_in_cache, results
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
- content = yield self._simple_select_one_onecol(
+ content = yield self.db.simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content",
desc="_get_cached_user_device",
)
- defer.returnValue(db_to_json(content))
+ return db_to_json(content)
@cachedInlineCallbacks()
- def _get_cached_devices_for_user(self, user_id):
- devices = yield self._simple_select_list(
+ def get_cached_devices_for_user(self, user_id):
+ devices = yield self.db.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
- desc="_get_cached_devices_for_user",
- )
- defer.returnValue(
- {device["device_id"]: db_to_json(device["content"]) for device in devices}
+ desc="get_cached_devices_for_user",
)
+ return {
+ device["device_id"]: db_to_json(device["content"]) for device in devices
+ }
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
@@ -363,7 +505,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
(stream_id, devices)
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn,
user_id,
@@ -385,6 +527,13 @@ class DeviceWorkerStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
+
+ if "signatures" in device:
+ for sig_user_id, sigs in device["signatures"].items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@@ -395,22 +544,72 @@ class DeviceWorkerStore(SQLBaseStore):
return now_stream_id, []
- @defer.inlineCallbacks
- def get_user_whose_devices_changed(self, from_key):
- """Get set of users whose devices have changed since `from_key`.
+ def get_users_whose_devices_changed(self, from_key, user_ids):
+ """Get the set of users in `user_ids` whose devices have changed since
+ `from_key`.
+
+ Args:
+ from_key (str): The device lists stream token
+ user_ids (Iterable[str])
+
+ Returns:
+ Deferred[set[str]]: The set of user_ids whose devices have changed
+ since `from_key`
"""
from_key = int(from_key)
- changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
- if changed is not None:
- defer.returnValue(set(changed))
- sql = """
- SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
- """
- rows = yield self._execute(
- "get_user_whose_devices_changed", None, sql, from_key
+ # Get set of users who *may* have changed. Users not in the returned
+ # list have definitely not changed.
+ to_check = list(
+ self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+ )
+
+ if not to_check:
+ return defer.succeed(set())
+
+ def _get_users_whose_devices_changed_txn(txn):
+ changes = set()
+
+ sql = """
+ SELECT DISTINCT user_id FROM device_lists_stream
+ WHERE stream_id > ?
+ AND
+ """
+
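+ # Query in chunks of 100 user IDs to keep the generated IN (...) clause
+ # to a manageable size.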
+ for chunk in batch_iter(to_check, 100):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "user_id", chunk
+ )
+ txn.execute(sql + clause, (from_key,) + tuple(args))
+ changes.update(user_id for user_id, in txn)
+
+ return changes
+
+ return self.db.runInteraction(
+ "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
)
- defer.returnValue(set(row[0] for row in rows))
+
+ @defer.inlineCallbacks
+ def get_users_whose_signatures_changed(self, user_id, from_key):
+ """Get the users who have new cross-signing signatures made by `user_id` since
+ `from_key`.
+
+ Args:
+ user_id (str): the user who made the signatures
+ from_key (str): The device lists stream token
+ """
+ from_key = int(from_key)
+ if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
+ sql = """
+ SELECT DISTINCT user_ids FROM user_signature_stream
+ WHERE from_user_id = ? AND stream_id > ?
+ """
+ rows = yield self.db.execute(
+ "get_users_whose_signatures_changed", None, sql, user_id, from_key
+ )
+ return {user for row in rows for user in json.loads(row[0])}
+ else:
+ return set()
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
@@ -426,7 +625,7 @@ class DeviceWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
- return self._execute(
+ return self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
@@ -435,7 +634,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
@@ -449,7 +648,7 @@ class DeviceWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
- rows = yield self._simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
@@ -460,22 +659,45 @@ class DeviceWorkerStore(SQLBaseStore):
results = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows})
- defer.returnValue(results)
+ return results
+ @defer.inlineCallbacks
+ def get_user_ids_requiring_device_list_resync(self, user_ids: Collection[str]):
+ """Given a list of remote users, return the list of users whose device
+ lists we should resync.
-class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
- def __init__(self, db_conn, hs):
- super(DeviceStore, self).__init__(db_conn, hs)
+ Returns:
+ Deferred[Set[str]]
+ """
- # Map of (user_id, device_id) -> bool. If there is an entry that implies
- # the device exists.
- self.device_id_exists_cache = Cache(
- name="device_id_exists", keylen=2, max_entries=10000
+ rows = yield self.db.simple_select_many_batch(
+ table="device_lists_remote_resync",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync",
+ )
+
+ return {row["user_id"] for row in rows}
+
+ def mark_remote_user_device_cache_as_stale(self, user_id: str):
+ """Records that the server has reason to believe the cache of the devices
+ for the remote user is out of date.
+ """
+ return self.db.simple_upsert(
+ table="device_lists_remote_resync",
+ keyvalues={"user_id": user_id},
+ values={},
+ insertion_values={"added_ts": self._clock.time_msec()},
+ desc="mark_remote_user_device_cache_as_stale",
)
- self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
- self.register_background_index_update(
+class DeviceBackgroundUpdateStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
+ self.db.updates.register_background_index_update(
"device_lists_stream_idx",
index_name="device_lists_stream_user_id",
table="device_lists_stream",
@@ -483,7 +705,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
# create a unique index on device_lists_remote_cache
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"device_lists_remote_cache_unique_idx",
index_name="device_lists_remote_cache_unique_id",
table="device_lists_remote_cache",
@@ -492,7 +714,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
# And one on device_lists_remote_extremeties
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"device_lists_remote_extremeties_unique_idx",
index_name="device_lists_remote_extremeties_unique_idx",
table="device_lists_remote_extremeties",
@@ -501,12 +723,39 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
# once they complete, we can remove the old non-unique indexes.
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
self._drop_device_list_streams_non_unique_indexes,
)
@defer.inlineCallbacks
+ def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
+ def f(conn):
+ txn = conn.cursor()
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
+ txn.close()
+
+ yield self.db.runWithConnection(f)
+ yield self.db.updates._end_background_update(
+ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
+ )
+ return 1
+
+
+class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceStore, self).__init__(database, db_conn, hs)
+
+ # Map of (user_id, device_id) -> bool. If there is an entry that implies
+ # the device exists.
+ self.device_id_exists_cache = Cache(
+ name="device_id_exists", keylen=2, max_entries=10000
+ )
+
+ self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
+
+ @defer.inlineCallbacks
def store_device(self, user_id, device_id, initial_device_display_name):
"""Ensure the given device is known; add it to the store if not
@@ -518,24 +767,39 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
Returns:
defer.Deferred: boolean whether the device was inserted or an
existing device existed with that ID.
+ Raises:
+ StoreError: if the device is already in use
"""
key = (user_id, device_id)
if self.device_id_exists_cache.get(key, None):
- defer.returnValue(False)
+ return False
try:
- inserted = yield self._simple_insert(
+ inserted = yield self.db.simple_insert(
"devices",
values={
"user_id": user_id,
"device_id": device_id,
"display_name": initial_device_display_name,
+ "hidden": False,
},
desc="store_device",
or_ignore=True,
)
+ if not inserted:
+ # if the device already exists, check if it's a real device, or
+ # if the device ID is reserved by something else
+ hidden = yield self.db.simple_select_one_onecol(
+ "devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="hidden",
+ )
+ if hidden:
+ raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
self.device_id_exists_cache.prefill(key, True)
- defer.returnValue(inserted)
+ return inserted
+ except StoreError:
+ raise
except Exception as e:
logger.error(
"store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -560,9 +824,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
Returns:
defer.Deferred
"""
- yield self._simple_delete_one(
+ yield self.db.simple_delete_one(
table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device",
)
@@ -578,18 +842,19 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
Returns:
defer.Deferred
"""
- yield self._simple_delete_many(
+ yield self.db.simple_delete_many(
table="devices",
column="device_id",
iterable=device_ids,
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
desc="delete_devices",
)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
def update_device(self, user_id, device_id, new_display_name=None):
- """Update a device.
+ """Update a device. Only updates the device if it is not marked as
+ hidden.
Args:
user_id (str): The ID of the user which owns the device
@@ -606,9 +871,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
- return self._simple_update_one(
+ return self.db.simple_update_one(
table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates,
desc="update_device",
)
@@ -617,7 +882,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
- yield self._simple_delete(
+ yield self.db.simple_delete(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
desc="mark_remote_user_device_list_as_unsubscribed",
@@ -641,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
Returns:
Deferred[None]
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@@ -654,7 +919,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self, txn, user_id, device_id, content, stream_id
):
if content.get("deleted"):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -662,7 +927,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
else:
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -673,12 +938,12 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
- txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -702,7 +967,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
Returns:
Deferred[None]
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@@ -711,11 +976,11 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
@@ -728,13 +993,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
],
)
- txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -744,13 +1009,20 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
lock=False,
)
+ # If we're replacing the remote user's device list cache presumably
+ # we've done a full resync, so we remove the entry that says we need
+ # to resync
+ self.db.simple_delete_txn(
+ txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
+ )
+
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_device_change_to_streams",
self._add_device_change_txn,
user_id,
@@ -758,7 +1030,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
hosts,
stream_id,
)
- defer.returnValue(stream_id)
+ return stream_id
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
@@ -783,7 +1055,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
[(user_id, device_id, stream_id) for device_id in device_ids],
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
@@ -792,7 +1064,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
],
)
- self._simple_insert_many_txn(
+ context = get_active_span_text_map()
+
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
@@ -803,6 +1077,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"device_id": device_id,
"sent": False,
"ts": now,
+ "opentracing_context": json.dumps(context)
+ if whitelisted_homeserver(destination)
+ else "{}",
}
for destination in hosts
for device_id in device_ids
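
A minimal, standalone sketch of the row construction in this hunk, assuming only that `get_active_span_text_map()` yields a dict of tracing headers and `whitelisted_homeserver()` is an allow-list check (both names come from the diff; the bodies below are illustrative stand-ins, not Synapse's implementations):

```python
import json

def whitelisted_homeserver(destination, whitelist=("example.org",)):
    # Illustrative stand-in: only attach tracing context for allow-listed servers.
    return destination in whitelist

def get_active_span_text_map():
    # Illustrative stand-in for the opentracing helper: a dict of text headers.
    return {"uber-trace-id": "abc123:def456:0:1"}

def build_outbound_poke_rows(user_id, device_ids, hosts, stream_id, now):
    context = get_active_span_text_map()
    return [
        {
            "destination": destination,
            "stream_id": stream_id,
            "user_id": user_id,
            "device_id": device_id,
            "sent": False,
            "ts": now,
            # Serialise the tracing context only for whitelisted destinations.
            "opentracing_context": json.dumps(context)
            if whitelisted_homeserver(destination)
            else "{}",
        }
        for destination in hosts
        for device_id in device_ids
    ]
```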
@@ -852,19 +1129,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
return run_as_background_process(
"prune_old_outbound_device_pokes",
- self.runInteraction,
+ self.db.runInteraction,
"_prune_old_outbound_device_pokes",
_prune_txn,
)
-
- @defer.inlineCallbacks
- def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
- def f(conn):
- txn = conn.cursor()
- txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
- txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
- txn.close()
-
- yield self.runWithConnection(f)
- yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
- defer.returnValue(1)
diff --git a/synapse/storage/directory.py b/synapse/storage/data_stores/main/directory.py
index 201bbd430c..c9e7de7d12 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -18,10 +18,9 @@ from collections import namedtuple
from twisted.internet import defer
from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
-from ._base import SQLBaseStore
-
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
@@ -37,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore):
Deferred: results in namedtuple with keys "room_id" and
"servers" or None if no association can be found
"""
- room_id = yield self._simple_select_one_onecol(
+ room_id = yield self.db.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
@@ -46,10 +45,9 @@ class DirectoryWorkerStore(SQLBaseStore):
)
if not room_id:
- defer.returnValue(None)
- return
+ return None
- servers = yield self._simple_select_onecol(
+ servers = yield self.db.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
@@ -57,13 +55,12 @@ class DirectoryWorkerStore(SQLBaseStore):
)
if not servers:
- defer.returnValue(None)
- return
+ return None
- defer.returnValue(RoomAliasMapping(room_id, room_alias.to_string(), servers))
+ return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
@@ -72,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore):
@cached(max_entries=5000)
def get_aliases_for_room(self, room_id):
- return self._simple_select_onecol(
+ return self.db.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
@@ -96,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore):
"""
def alias_txn(txn):
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
"room_aliases",
{
@@ -106,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore):
},
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[
@@ -120,20 +117,22 @@ class DirectoryStore(DirectoryWorkerStore):
)
try:
- ret = yield self.runInteraction("create_room_alias_association", alias_txn)
+ ret = yield self.db.runInteraction(
+ "create_room_alias_association", alias_txn
+ )
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
- room_id = yield self.runInteraction(
+ room_id = yield self.db.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
- defer.returnValue(room_id)
+ return room_id
def _delete_room_alias_txn(self, txn, room_alias):
txn.execute(
@@ -171,6 +170,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
- return self.runInteraction(
+ return self.db.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
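
The `create_room_alias_association` change above keeps the pattern of letting the database's unique constraint catch duplicate aliases and turning the resulting `IntegrityError` into a 409. A rough sqlite3 sketch of that pattern, with a hypothetical table and error type standing in for Synapse's:

```python
import sqlite3

class AliasExistsError(Exception):
    """Stand-in for SynapseError(409, ...)."""

def create_room_alias_association(conn, room_alias, room_id, creator):
    # Rely on the UNIQUE constraint rather than a racy SELECT-then-INSERT.
    try:
        with conn:
            conn.execute(
                "INSERT INTO room_aliases (room_alias, room_id, creator) VALUES (?, ?, ?)",
                (room_alias, room_id, creator),
            )
    except sqlite3.IntegrityError:
        raise AliasExistsError("Room alias %s already exists" % room_alias)

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE room_aliases (room_alias TEXT PRIMARY KEY, room_id TEXT, creator TEXT)"
)
create_room_alias_association(
    conn, "#foo:example.org", "!abc:example.org", "@alice:example.org"
)
```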
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 521936e3b0..84594cf0a9 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,55 +19,14 @@ import json
from twisted.internet import defer
from synapse.api.errors import StoreError
-
-from ._base import SQLBaseStore
+from synapse.logging.opentracing import log_kv, trace
+from synapse.storage._base import SQLBaseStore
class EndToEndRoomKeyStore(SQLBaseStore):
@defer.inlineCallbacks
- def get_e2e_room_key(self, user_id, version, room_id, session_id):
- """Get the encrypted E2E room key for a given session from a given
- backup version of room_keys. We only store the 'best' room key for a given
- session at a given time, as determined by the handler.
-
- Args:
- user_id(str): the user whose backup we're querying
- version(str): the version ID of the backup for the set of keys we're querying
- room_id(str): the ID of the room whose keys we're querying.
- This is a bit redundant as it's implied by the session_id, but
- we include for consistency with the rest of the API.
- session_id(str): the session whose room_key we're querying.
-
- Returns:
- A deferred dict giving the session_data and message metadata for
- this room key.
- """
-
- row = yield self._simple_select_one(
- table="e2e_room_keys",
- keyvalues={
- "user_id": user_id,
- "version": version,
- "room_id": room_id,
- "session_id": session_id,
- },
- retcols=(
- "first_message_index",
- "forwarded_count",
- "is_verified",
- "session_data",
- ),
- desc="get_e2e_room_key",
- )
-
- row["session_data"] = json.loads(row["session_data"])
-
- defer.returnValue(row)
-
- @defer.inlineCallbacks
- def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
- """Replaces or inserts the encrypted E2E room key for a given session in
- a given backup
+ def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+ """Replaces the encrypted E2E room key for a given session in a given backup
Args:
user_id(str): the user whose backup we're setting
@@ -78,34 +38,73 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError
"""
- yield self._simple_upsert(
+ yield self.db.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
+ "version": version,
"room_id": room_id,
"session_id": session_id,
},
- values={
- "version": version,
- "first_message_index": room_key['first_message_index'],
- "forwarded_count": room_key['forwarded_count'],
- "is_verified": room_key['is_verified'],
- "session_data": json.dumps(room_key['session_data']),
+ updatevalues={
+ "first_message_index": room_key["first_message_index"],
+ "forwarded_count": room_key["forwarded_count"],
+ "is_verified": room_key["is_verified"],
+ "session_data": json.dumps(room_key["session_data"]),
},
- lock=False,
+ desc="update_e2e_room_key",
+ )
+
+ @defer.inlineCallbacks
+ def add_e2e_room_keys(self, user_id, version, room_keys):
+ """Bulk add room keys to a given backup.
+
+ Args:
+ user_id (str): the user whose backup we're adding to
+ version (str): the version ID of the backup for the set of keys we're adding to
+ room_keys (iterable[(str, str, dict)]): the keys to add, in the form
+ (roomID, sessionID, keyData)
+ """
+
+ values = []
+ for (room_id, session_id, room_key) in room_keys:
+ values.append(
+ {
+ "user_id": user_id,
+ "version": version,
+ "room_id": room_id,
+ "session_id": session_id,
+ "first_message_index": room_key["first_message_index"],
+ "forwarded_count": room_key["forwarded_count"],
+ "is_verified": room_key["is_verified"],
+ "session_data": json.dumps(room_key["session_data"]),
+ }
+ )
+ log_kv(
+ {
+ "message": "Set room key",
+ "room_id": room_id,
+ "session_id": session_id,
+ "room_key": room_key,
+ }
+ )
+
+ yield self.db.simple_insert_many(
+ table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)
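
For illustration, the `room_keys` iterable is flattened into one row per `(room_id, session_id, key_data)` tuple; a sketch of what a caller might pass and how it maps onto rows (all field values are made up):

```python
import json

room_keys = [
    (
        "!room:example.org",
        "session_id_1",
        {
            "first_message_index": 1,
            "forwarded_count": 0,
            "is_verified": False,
            "session_data": {"ciphertext": "..."},
        },
    ),
]

# Mirrors the row construction in add_e2e_room_keys above.
values = [
    {
        "user_id": "@alice:example.org",
        "version": "2",
        "room_id": room_id,
        "session_id": session_id,
        "first_message_index": room_key["first_message_index"],
        "forwarded_count": room_key["forwarded_count"],
        "is_verified": room_key["is_verified"],
        "session_data": json.dumps(room_key["session_data"]),
    }
    for room_id, session_id, room_key in room_keys
]
```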
+ @trace
@defer.inlineCallbacks
def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
Args:
- user_id(str): the user whose backup we're querying
- version(str): the version ID of the backup for the set of keys we're querying
- room_id(str): Optional. the ID of the room whose keys we're querying, if any.
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup for the set of keys we're querying
+ room_id (str): Optional. the ID of the room whose keys we're querying, if any.
If not specified, we return the keys for all the rooms in the backup.
- session_id(str): Optional. the session whose room_key we're querying, if any.
+ session_id (str): Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we return all the keys in this version of
the backup (or for the specified room)
@@ -118,15 +117,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
try:
version = int(version)
except ValueError:
- defer.returnValue({'rooms': {}})
+ return {"rooms": {}}
keyvalues = {"user_id": user_id, "version": version}
if room_id:
- keyvalues['room_id'] = room_id
+ keyvalues["room_id"] = room_id
if session_id:
- keyvalues['session_id'] = session_id
+ keyvalues["session_id"] = session_id
- rows = yield self._simple_select_list(
+ rows = yield self.db.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
@@ -141,18 +140,108 @@ class EndToEndRoomKeyStore(SQLBaseStore):
desc="get_e2e_room_keys",
)
- sessions = {'rooms': {}}
+ sessions = {"rooms": {}}
for row in rows:
- room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}})
- room_entry['sessions'][row['session_id']] = {
+ room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
+ room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
"is_verified": row["is_verified"],
"session_data": json.loads(row["session_data"]),
}
- defer.returnValue(sessions)
+ return sessions
+
+ def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+ """Get multiple room keys at a time. The difference between this function and
+ get_e2e_room_keys is that this function can be used to retrieve
+ multiple specific keys at a time, whereas get_e2e_room_keys is used for
+ getting all the keys in a backup version, all the keys for a room, or a
+ specific key.
+
+ Args:
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup we're querying about
+ room_keys (dict[str, dict[str, iterable[str]]]): a map from
+ room ID -> {"session": [session ids]} indicating the session IDs
+ that we want to query
+
+ Returns:
+ Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
+ """
+
+ return self.db.runInteraction(
+ "get_e2e_room_keys_multi",
+ self._get_e2e_room_keys_multi_txn,
+ user_id,
+ version,
+ room_keys,
+ )
+
+ @staticmethod
+ def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+ if not room_keys:
+ return {}
+
+ where_clauses = []
+ params = [user_id, version]
+ for room_id, room in room_keys.items():
+ sessions = list(room["sessions"])
+ if not sessions:
+ continue
+ params.append(room_id)
+ params.extend(sessions)
+ where_clauses.append(
+ "(room_id = ? AND session_id IN (%s))"
+ % (",".join(["?" for _ in sessions]),)
+ )
+
+ # check if we're actually querying something
+ if not where_clauses:
+ return {}
+
+ sql = """
+ SELECT room_id, session_id, first_message_index, forwarded_count,
+ is_verified, session_data
+ FROM e2e_room_keys
+ WHERE user_id = ? AND version = ? AND (%s)
+ """ % (
+ " OR ".join(where_clauses)
+ )
+
+ txn.execute(sql, params)
+
+ ret = {}
+
+ for row in txn:
+ room_id = row[0]
+ session_id = row[1]
+ ret.setdefault(room_id, {})
+ ret[room_id][session_id] = {
+ "first_message_index": row[2],
+ "forwarded_count": row[3],
+ "is_verified": row[4],
+ "session_data": json.loads(row[5]),
+ }
+
+ return ret
+
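The txn helper above builds one `(room_id = ? AND session_id IN (...))` clause per requested room. A standalone sketch of that clause construction (the function name and example values are local to this sketch):

```python
def build_room_keys_where(user_id, version, room_keys):
    """Build the WHERE fragment and parameters used by
    _get_e2e_room_keys_multi_txn, given e.g.
    {"!room:example.org": {"sessions": ["s1", "s2"]}}.
    """
    where_clauses = []
    params = [user_id, version]
    for room_id, room in room_keys.items():
        sessions = list(room["sessions"])
        if not sessions:
            continue
        params.append(room_id)
        params.extend(sessions)
        where_clauses.append(
            "(room_id = ? AND session_id IN (%s))" % ",".join("?" for _ in sessions)
        )
    return " OR ".join(where_clauses), params

clause, params = build_room_keys_where(
    "@alice:example.org", 2, {"!room:example.org": {"sessions": ["s1", "s2"]}}
)
# clause == "(room_id = ? AND session_id IN (?,?))"
# params == ["@alice:example.org", 2, "!room:example.org", "s1", "s2"]
```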
+ def count_e2e_room_keys(self, user_id, version):
+ """Get the number of keys in a backup version.
+
+ Args:
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup we're querying about
+ """
+
+ return self.db.simple_select_one_onecol(
+ table="e2e_room_keys",
+ keyvalues={"user_id": user_id, "version": version},
+ retcol="COUNT(*)",
+ desc="count_e2e_room_keys",
+ )
+ @trace
@defer.inlineCallbacks
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
@@ -174,11 +263,11 @@ class EndToEndRoomKeyStore(SQLBaseStore):
keyvalues = {"user_id": user_id, "version": int(version)}
if room_id:
- keyvalues['room_id'] = room_id
+ keyvalues["room_id"] = room_id
if session_id:
- keyvalues['session_id'] = session_id
+ keyvalues["session_id"] = session_id
- yield self._simple_delete(
+ yield self.db.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@@ -191,7 +280,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
row = txn.fetchone()
if not row:
- raise StoreError(404, 'No current backup version')
+ raise StoreError(404, "No current backup version")
return row[0]
def get_e2e_room_keys_version_info(self, user_id, version=None):
@@ -209,6 +298,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version(str)
algorithm(str)
auth_data(object): opaque dict supplied by the client
+ etag(int): tag of the keys in the backup
"""
def _get_e2e_room_keys_version_info_txn(txn):
@@ -222,20 +312,23 @@ class EndToEndRoomKeyStore(SQLBaseStore):
# it isn't there.
raise StoreError(404, "No row found")
- result = self._simple_select_one_txn(
+ result = self.db.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
- retcols=("version", "algorithm", "auth_data"),
+ retcols=("version", "algorithm", "auth_data", "etag"),
)
result["auth_data"] = json.loads(result["auth_data"])
result["version"] = str(result["version"])
+ if result["etag"] is None:
+ result["etag"] = 0
return result
- return self.runInteraction(
+ return self.db.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
+ @trace
def create_e2e_room_keys_version(self, user_id, info):
"""Atomically creates a new version of this user's e2e_room_keys store
with the given version info.
@@ -255,11 +348,11 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
current_version = txn.fetchone()[0]
if current_version is None:
- current_version = '0'
+ current_version = "0"
new_version = str(int(current_version) + 1)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="e2e_room_keys_versions",
values={
@@ -272,26 +365,40 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return new_version
- return self.runInteraction(
+ return self.db.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
- def update_e2e_room_keys_version(self, user_id, version, info):
+ @trace
+ def update_e2e_room_keys_version(
+ self, user_id, version, info=None, version_etag=None
+ ):
"""Update a given backup version
Args:
user_id(str): the user whose backup version we're updating
version(str): the version ID of the backup version we're updating
- info(dict): the new backup version info to store
+ info (dict): the new backup version info to store. If None, then
+ the backup version info is not updated
+ version_etag (Optional[int]): etag of the keys in the backup. If
+ None, then the etag is not updated
"""
+ updatevalues = {}
- return self._simple_update(
- table="e2e_room_keys_versions",
- keyvalues={"user_id": user_id, "version": version},
- updatevalues={"auth_data": json.dumps(info["auth_data"])},
- desc="update_e2e_room_keys_version",
- )
+ if info is not None and "auth_data" in info:
+ updatevalues["auth_data"] = json.dumps(info["auth_data"])
+ if version_etag is not None:
+ updatevalues["etag"] = version_etag
+
+ if updatevalues:
+ return self.db.simple_update(
+ table="e2e_room_keys_versions",
+ keyvalues={"user_id": user_id, "version": version},
+ updatevalues=updatevalues,
+ desc="update_e2e_room_keys_version",
+ )
+ @trace
def delete_e2e_room_keys_version(self, user_id, version=None):
"""Delete a given backup version of the user's room keys.
Doesn't delete their actual key data.
@@ -308,16 +415,24 @@ class EndToEndRoomKeyStore(SQLBaseStore):
def _delete_e2e_room_keys_version_txn(txn):
if version is None:
this_version = self._get_current_version(txn, user_id)
+ if this_version is None:
+ raise StoreError(404, "No current backup version")
else:
this_version = version
- return self._simple_update_one_txn(
+ self.db.simple_delete_txn(
+ txn,
+ table="e2e_room_keys",
+ keyvalues={"user_id": user_id, "version": this_version},
+ )
+
+ return self.db.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
new file mode 100644
index 0000000000..001a53f9b4
--- /dev/null
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -0,0 +1,771 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, List
+
+from six import iteritems
+
+from canonicaljson import encode_canonical_json, json
+
+from twisted.enterprise.adbapi import Connection
+from twisted.internet import defer
+
+from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.util.caches.descriptors import cached, cachedList
+
+
+class EndToEndKeyWorkerStore(SQLBaseStore):
+ @trace
+ @defer.inlineCallbacks
+ def get_e2e_device_keys(
+ self, query_list, include_all_devices=False, include_deleted_devices=False
+ ):
+ """Fetch a list of device keys.
+ Args:
+ query_list(list): List of pairs of user_ids and device_ids.
+ include_all_devices (bool): whether to include entries for devices
+ that don't have device keys
+ include_deleted_devices (bool): whether to include null entries for
+ devices which no longer exist (but were in the query_list).
+ This option only takes effect if include_all_devices is true.
+ Returns:
+ Dict mapping from user-id to dict mapping from device_id to
+ key data. The key data will be a dict in the same format as the
+ DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
+ """
+ set_tag("query_list", query_list)
+ if not query_list:
+ return {}
+
+ results = yield self.db.runInteraction(
+ "get_e2e_device_keys",
+ self._get_e2e_device_keys_txn,
+ query_list,
+ include_all_devices,
+ include_deleted_devices,
+ )
+
+ # Build the result structure, un-jsonify the results, and add the
+ # "unsigned" section
+ rv = {}
+ for user_id, device_keys in iteritems(results):
+ rv[user_id] = {}
+ for device_id, device_info in iteritems(device_keys):
+ r = db_to_json(device_info.pop("key_json"))
+ r["unsigned"] = {}
+ display_name = device_info["device_display_name"]
+ if display_name is not None:
+ r["unsigned"]["device_display_name"] = display_name
+ if "signatures" in device_info:
+ for sig_user_id, sigs in device_info["signatures"].items():
+ r.setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+ rv[user_id][device_id] = r
+
+ return rv
+
+ @trace
+ def _get_e2e_device_keys_txn(
+ self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+ ):
+ set_tag("include_all_devices", include_all_devices)
+ set_tag("include_deleted_devices", include_deleted_devices)
+
+ query_clauses = []
+ query_params = []
+ signature_query_clauses = []
+ signature_query_params = []
+
+ if include_all_devices is False:
+ include_deleted_devices = False
+
+ if include_deleted_devices:
+ deleted_devices = set(query_list)
+
+ for (user_id, device_id) in query_list:
+ query_clause = "user_id = ?"
+ query_params.append(user_id)
+ signature_query_clause = "target_user_id = ?"
+ signature_query_params.append(user_id)
+
+ if device_id is not None:
+ query_clause += " AND device_id = ?"
+ query_params.append(device_id)
+ signature_query_clause += " AND target_device_id = ?"
+ signature_query_params.append(device_id)
+
+ signature_query_clause += " AND user_id = ?"
+ signature_query_params.append(user_id)
+
+ query_clauses.append(query_clause)
+ signature_query_clauses.append(signature_query_clause)
+
+ sql = (
+ "SELECT user_id, device_id, "
+ " d.display_name AS device_display_name, "
+ " k.key_json"
+ " FROM devices d"
+ " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
+ " WHERE %s AND NOT d.hidden"
+ ) % (
+ "LEFT" if include_all_devices else "INNER",
+ " OR ".join("(" + q + ")" for q in query_clauses),
+ )
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ result = {}
+ for row in rows:
+ if include_deleted_devices:
+ deleted_devices.remove((row["user_id"], row["device_id"]))
+ result.setdefault(row["user_id"], {})[row["device_id"]] = row
+
+ if include_deleted_devices:
+ for user_id, device_id in deleted_devices:
+ result.setdefault(user_id, {})[device_id] = None
+
+ # get signatures on the device
+ signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
+ " OR ".join("(" + q + ")" for q in signature_query_clauses)
+ )
+
+ txn.execute(signature_sql, signature_query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ # add each cross-signing signature to the correct device in the result dict.
+ for row in rows:
+ signing_user_id = row["user_id"]
+ signing_key_id = row["key_id"]
+ target_user_id = row["target_user_id"]
+ target_device_id = row["target_device_id"]
+ signature = row["signature"]
+
+ target_user_result = result.get(target_user_id)
+ if not target_user_result:
+ continue
+
+ target_device_result = target_user_result.get(target_device_id)
+ if not target_device_result:
+ # note that target_device_result will be None for deleted devices.
+ continue
+
+ target_device_signatures = target_device_result.setdefault("signatures", {})
+ signing_user_signatures = target_device_signatures.setdefault(
+ signing_user_id, {}
+ )
+ signing_user_signatures[signing_key_id] = signature
+
+ log_kv(result)
+ return result
+
+ @defer.inlineCallbacks
+ def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+ """Retrieve a number of one-time keys for a user
+
+ Args:
+ user_id(str): id of user to get keys for
+ device_id(str): id of device to get keys for
+ key_ids(list[str]): list of key ids (excluding algorithm) to
+ retrieve
+
+ Returns:
+ deferred resolving to Dict[(str, str), str]: map from (algorithm,
+ key_id) to json string for key
+ """
+
+ rows = yield self.db.simple_select_many_batch(
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ retcols=("algorithm", "key_id", "key_json"),
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="add_e2e_one_time_keys_check",
+ )
+ result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
+ log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
+ return result
+
+ @defer.inlineCallbacks
+ def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+ """Insert some new one time keys for a device. Errors if any of the
+ keys already exist.
+
+ Args:
+ user_id(str): id of user to get keys for
+ device_id(str): id of device to get keys for
+ time_now(long): insertion time to record (ms since epoch)
+ new_keys(iterable[(str, str, str)]): keys to add - each a tuple of
+ (algorithm, key_id, key json)
+ """
+
+ def _add_e2e_one_time_keys(txn):
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("new_keys", new_keys)
+ # We are protected from a race between lookup and insertion due to
+ # a unique constraint. If there is a race of two calls to
+ # `add_e2e_one_time_keys` then they'll conflict and we will only
+ # insert one set.
+ self.db.simple_insert_many_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ values=[
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ "key_id": key_id,
+ "ts_added_ms": time_now,
+ "key_json": json_bytes,
+ }
+ for algorithm, key_id, json_bytes in new_keys
+ ],
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ yield self.db.runInteraction(
+ "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
+ )
+
+ @cached(max_entries=10000)
+ def count_e2e_one_time_keys(self, user_id, device_id):
+ """ Count the number of one time keys the server has for a device
+ Returns:
+ Dict mapping from algorithm to number of keys for that algorithm.
+ """
+
+ def _count_e2e_one_time_keys(txn):
+ sql = (
+ "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ?"
+ " GROUP BY algorithm"
+ )
+ txn.execute(sql, (user_id, device_id))
+ result = {}
+ for algorithm, key_count in txn:
+ result[algorithm] = key_count
+ return result
+
+ return self.db.runInteraction(
+ "count_e2e_one_time_keys", _count_e2e_one_time_keys
+ )
+
+ def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
+ """Returns a user's cross-signing key.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ user_id (str): the user whose key is being requested
+ key_type (str): the type of key that is being requested: either 'master'
+ for a master key, 'self_signing' for a self-signing key, or
+ 'user_signing' for a user-signing key
+ from_user_id (str): if specified, signatures made by this user on
+ the key will be included in the result
+
+ Returns:
+ dict of the key data or None if not found
+ """
+ sql = (
+ "SELECT keydata "
+ " FROM e2e_cross_signing_keys "
+ " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
+ )
+ txn.execute(sql, (user_id, key_type))
+ row = txn.fetchone()
+ if not row:
+ return None
+ key = json.loads(row[0])
+
+ device_id = None
+ for k in key["keys"].values():
+ device_id = k
+
+ if from_user_id is not None:
+ sql = (
+ "SELECT key_id, signature "
+ " FROM e2e_cross_signing_signatures "
+ " WHERE user_id = ? "
+ " AND target_user_id = ? "
+ " AND target_device_id = ? "
+ )
+ txn.execute(sql, (from_user_id, user_id, device_id))
+ row = txn.fetchone()
+ if row:
+ key.setdefault("signatures", {}).setdefault(from_user_id, {})[
+ row[0]
+ ] = row[1]
+
+ return key
+
+ def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+ """Returns a user's cross-signing key.
+
+ Args:
+ user_id (str): the user whose key is being requested
+ key_type (str): the type of key that is being requested: either 'master'
+ for a master key, 'self_signing' for a self-signing key, or
+ 'user_signing' for a user-signing key
+ from_user_id (str): if specified, signatures made by this user on
+ the self-signing key will be included in the result
+
+ Returns:
+ dict of the key data or None if not found
+ """
+ return self.db.runInteraction(
+ "get_e2e_cross_signing_key",
+ self._get_e2e_cross_signing_key_txn,
+ user_id,
+ key_type,
+ from_user_id,
+ )
+
+ @cached(num_args=1)
+ def _get_bare_e2e_cross_signing_keys(self, user_id):
+ """Dummy function. Only used to make a cache for
+ _get_bare_e2e_cross_signing_keys_bulk.
+ """
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_bare_e2e_cross_signing_keys",
+ list_name="user_ids",
+ num_args=1,
+ )
+ def _get_bare_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str]
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing keys for a set of users. The output of this
+ function should be passed to _get_e2e_cross_signing_signatures_txn if
+ the signatures for the calling user need to be fetched.
+
+ Args:
+ user_ids (list[str]): the users whose keys are being requested
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. If a user's cross-signing keys were not found, either
+ their user ID will not be in the dict, or their user ID will map
+ to None.
+
+ """
+ return self.db.runInteraction(
+ "get_bare_e2e_cross_signing_keys_bulk",
+ self._get_bare_e2e_cross_signing_keys_bulk_txn,
+ user_ids,
+ )
+
+ def _get_bare_e2e_cross_signing_keys_bulk_txn(
+ self, txn: Connection, user_ids: List[str],
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing keys for a set of users. The output of this
+ function should be passed to _get_e2e_cross_signing_signatures_txn if
+ the signatures for the calling user need to be fetched.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ user_ids (list[str]): the users whose keys are being requested
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. If a user's cross-signing keys were not found, their user
+ ID will not be in the dict.
+
+ """
+ result = {}
+
+ batch_size = 100
+ chunks = [
+ user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
+ ]
+ for user_chunk in chunks:
+ sql = """
+ SELECT k.user_id, k.keytype, k.keydata, k.stream_id
+ FROM e2e_cross_signing_keys k
+ INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
+ FROM e2e_cross_signing_keys
+ GROUP BY user_id, keytype) s
+ USING (user_id, stream_id, keytype)
+ WHERE k.user_id IN (%s)
+ """ % (
+ ",".join("?" for u in user_chunk),
+ )
+ query_params = []
+ query_params.extend(user_chunk)
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ for row in rows:
+ user_id = row["user_id"]
+ key_type = row["keytype"]
+ key = json.loads(row["keydata"])
+ user_info = result.setdefault(user_id, {})
+ user_info[key_type] = key
+
+ return result
+
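The bulk lookup slices `user_ids` into chunks of 100 so the `IN (...)` placeholder list stays bounded. A sketch of the slicing and placeholder pattern on its own:

```python
def chunk(user_ids, batch_size=100):
    # Same slicing as above: contiguous batches of at most batch_size IDs.
    return [user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)]

user_ids = ["@user%d:example.org" % i for i in range(250)]
for user_chunk in chunk(user_ids):
    placeholders = ",".join("?" for _ in user_chunk)
    sql = (
        "SELECT user_id, keytype, keydata FROM e2e_cross_signing_keys"
        " WHERE user_id IN (%s)" % placeholders
    )
    # txn.execute(sql, user_chunk) in the real store.
```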
+ def _get_e2e_cross_signing_signatures_txn(
+ self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing signatures made by a user on a set of keys.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ keys (dict[str, dict[str, dict]]): a map of user ID to key type to
+ key data. This dict will be modified to add signatures.
+ from_user_id (str): fetch the signatures made by this user
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. The return value will be the same as the keys argument,
+ with the modifications included.
+ """
+
+ # find out what cross-signing keys (a.k.a. devices) we need to get
+ # signatures for. This is a map of (user_id, device_id) to key type
+ # (device_id is the key's public part).
+ devices = {}
+
+ for user_id, user_info in keys.items():
+ if user_info is None:
+ continue
+ for key_type, key in user_info.items():
+ device_id = None
+ for k in key["keys"].values():
+ device_id = k
+ devices[(user_id, device_id)] = key_type
+
+ device_list = list(devices)
+
+ # split into batches
+ batch_size = 100
+ chunks = [
+ device_list[i : i + batch_size]
+ for i in range(0, len(device_list), batch_size)
+ ]
+ for user_chunk in chunks:
+ sql = """
+ SELECT target_user_id, target_device_id, key_id, signature
+ FROM e2e_cross_signing_signatures
+ WHERE user_id = ?
+ AND (%s)
+ """ % (
+ " OR ".join(
+ "(target_user_id = ? AND target_device_id = ?)" for d in devices
+ )
+ )
+ query_params = [from_user_id]
+ for item in devices:
+ # item is a (user_id, device_id) tuple
+ query_params.extend(item)
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ # and add the signatures to the appropriate keys
+ for row in rows:
+ key_id = row["key_id"]
+ target_user_id = row["target_user_id"]
+ target_device_id = row["target_device_id"]
+ key_type = devices[(target_user_id, target_device_id)]
+ # We need to copy everything, because the result may have come
+ # from the cache. dict.copy only does a shallow copy, so we
+ # need to recursively copy the dicts that will be modified.
+ user_info = keys[target_user_id] = keys[target_user_id].copy()
+ target_user_key = user_info[key_type] = user_info[key_type].copy()
+ if "signatures" in target_user_key:
+ signatures = target_user_key["signatures"] = target_user_key[
+ "signatures"
+ ].copy()
+ if from_user_id in signatures:
+ user_sigs = signatures[from_user_id] = signatures[from_user_id]
+ user_sigs[key_id] = row["signature"]
+ else:
+ signatures[from_user_id] = {key_id: row["signature"]}
+ else:
+ target_user_key["signatures"] = {
+ from_user_id: {key_id: row["signature"]}
+ }
+
+ return keys
+
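Because the `keys` dict may have come from a cache, the transaction above copies each nested dict before attaching signatures. A small standalone illustration (with made-up data) of why a shallow copy alone would still mutate the cached value:

```python
cached = {"@alice:example.org": {"master": {"keys": {"ed25519:abc": "abc"}}}}

# A shallow copy still shares the nested "master" dict with the cache...
keys = dict(cached)
keys["@alice:example.org"] = keys["@alice:example.org"].copy()
target = keys["@alice:example.org"]["master"]  # same object as in `cached`

# ...so the inner dict must be copied too before mutating it, as the txn does.
keys["@alice:example.org"]["master"] = target = target.copy()
target.setdefault("signatures", {})["@bob:example.org"] = {"ed25519:xyz": "sig"}

assert "signatures" not in cached["@alice:example.org"]["master"]
```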
+ @defer.inlineCallbacks
+ def get_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str], from_user_id: str = None
+ ) -> defer.Deferred:
+ """Returns the cross-signing keys for a set of users.
+
+ Args:
+ user_ids (list[str]): the users whose keys are being requested
+ from_user_id (str): if specified, signatures made by this user on
+ the self-signing keys will be included in the result
+
+ Returns:
+ Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
+ key data. If a user's cross-signing keys were not found, either
+ their user ID will not be in the dict, or their user ID will map
+ to None.
+ """
+
+ result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+
+ if from_user_id:
+ result = yield self.db.runInteraction(
+ "get_e2e_cross_signing_signatures",
+ self._get_e2e_cross_signing_signatures_txn,
+ result,
+ from_user_id,
+ )
+
+ return result
+
+ def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+ """Return a list of changes from the user signature stream to notify remotes.
+ Note that the user signature stream represents when a user signs their
+ device with their user-signing key, which is not published to other
+ users or servers, so no `destination` is needed in the returned
+ list. However, this is needed to poke workers.
+
+ Args:
+ from_key (int): the stream ID to start at (exclusive)
+ to_key (int): the stream ID to end at (inclusive)
+
+ Returns:
+ Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
+ """
+ sql = """
+ SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+ FROM user_signature_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ GROUP BY user_id
+ """
+ return self.db.execute(
+ "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+ )
+
+
+class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+ def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+ """Stores device keys for a device. Returns whether there was a change
+ or the keys were already in the database.
+ """
+
+ def _set_e2e_device_keys_txn(txn):
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("time_now", time_now)
+ set_tag("device_keys", device_keys)
+
+ old_key_json = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="key_json",
+ allow_none=True,
+ )
+
+ # In py3 we need old_key_json to match new_key_json type. The DB
+ # returns unicode while encode_canonical_json returns bytes.
+ new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+
+ if old_key_json == new_key_json:
+ log_kv({"Message": "Device key already stored."})
+ return False
+
+ self.db.simple_upsert_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={"ts_added_ms": time_now, "key_json": new_key_json},
+ )
+ log_kv({"message": "Device keys stored."})
+ return True
+
+ return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
+
+ def claim_e2e_one_time_keys(self, query_list):
+ """Take a list of one time keys out of the database"""
+
+ @trace
+ def _claim_e2e_one_time_keys(txn):
+ sql = (
+ "SELECT key_id, key_json FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " LIMIT 1"
+ )
+ result = {}
+ delete = []
+ for user_id, device_id, algorithm in query_list:
+ user_result = result.setdefault(user_id, {})
+ device_result = user_result.setdefault(device_id, {})
+ txn.execute(sql, (user_id, device_id, algorithm))
+ for key_id, key_json in txn:
+ device_result[algorithm + ":" + key_id] = key_json
+ delete.append((user_id, device_id, algorithm, key_id))
+ sql = (
+ "DELETE FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " AND key_id = ?"
+ )
+ for user_id, device_id, algorithm, key_id in delete:
+ log_kv(
+ {
+ "message": "Executing claim e2e_one_time_keys transaction on database."
+ }
+ )
+ txn.execute(sql, (user_id, device_id, algorithm, key_id))
+ log_kv({"message": "finished executing and invalidating cache"})
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+ return result
+
+ return self.db.runInteraction(
+ "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
+ )
+
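The claim helper returns a nested map keyed by user ID, device ID and `"algorithm:key_id"`. A sketch of assembling that shape from a query list and some pre-fetched rows (the values are illustrative, not real keys):

```python
query_list = [("@alice:example.org", "DEVICE1", "signed_curve25519")]
fetched = {
    ("@alice:example.org", "DEVICE1", "signed_curve25519"): [("AAAAAQ", '{"key": "..."}')]
}

result = {}
for user_id, device_id, algorithm in query_list:
    device_result = result.setdefault(user_id, {}).setdefault(device_id, {})
    for key_id, key_json in fetched.get((user_id, device_id, algorithm), []):
        # Same "algorithm:key_id" naming as the store uses.
        device_result[algorithm + ":" + key_id] = key_json

# result == {"@alice:example.org": {"DEVICE1": {"signed_curve25519:AAAAAQ": '{"key": "..."}'}}}
```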
+ def delete_e2e_keys_by_device(self, user_id, device_id):
+ def delete_e2e_keys_by_device_txn(txn):
+ log_kv(
+ {
+ "message": "Deleting keys for device",
+ "device_id": device_id,
+ "user_id": user_id,
+ }
+ )
+ self.db.simple_delete_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self.db.simple_delete_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return self.db.runInteraction(
+ "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
+ )
+
+ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+ """Set a user's cross-signing key.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ user_id (str): the user to set the signing key for
+ key_type (str): the type of key that is being set: either 'master'
+ for a master key, 'self_signing' for a self-signing key, or
+ 'user_signing' for a user-signing key
+ key (dict): the key data
+ """
+ # the 'key' dict will look something like:
+ # {
+ # "user_id": "@alice:example.com",
+ # "usage": ["self_signing"],
+ # "keys": {
+ # "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key",
+ # },
+ # "signatures": {
+ # "@alice:example.com": {
+ # "ed25519:base64+master+public+key": "base64+signature"
+ # }
+ # }
+ # }
+ # The "keys" property must only have one entry, which will be the public
+ # key, so we just grab the first value in there
+ pubkey = next(iter(key["keys"].values()))
+
+ # The cross-signing keys need to occupy the same namespace as devices,
+ # since signatures are identified by device ID. So add an entry to the
+ # device table to make sure that we don't have a collision with device
+ # IDs.
+ # We only need to do this for local users, since remote servers should be
+ # responsible for checking this for their own users.
+ if self.hs.is_mine_id(user_id):
+ self.db.simple_insert_txn(
+ txn,
+ "devices",
+ values={
+ "user_id": user_id,
+ "device_id": pubkey,
+ "display_name": key_type + " signing key",
+ "hidden": True,
+ },
+ )
+
+ # and finally, store the key itself
+ with self._cross_signing_id_gen.get_next() as stream_id:
+ self.db.simple_insert_txn(
+ txn,
+ "e2e_cross_signing_keys",
+ values={
+ "user_id": user_id,
+ "keytype": key_type,
+ "keydata": json.dumps(key),
+ "stream_id": stream_id,
+ },
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
+ )
+
+ def set_e2e_cross_signing_key(self, user_id, key_type, key):
+ """Set a user's cross-signing key.
+
+ Args:
+ user_id (str): the user to set the user-signing key for
+ key_type (str): the type of cross-signing key to set
+ key (dict): the key data
+ """
+ return self.db.runInteraction(
+ "add_e2e_cross_signing_key",
+ self._set_e2e_cross_signing_key_txn,
+ user_id,
+ key_type,
+ key,
+ )
+
+ def store_e2e_cross_signing_signatures(self, user_id, signatures):
+ """Stores cross-signing signatures.
+
+ Args:
+ user_id (str): the user who made the signatures
+ signatures (iterable[SignatureListItem]): signatures to add
+ """
+ return self.db.simple_insert_many(
+ "e2e_cross_signing_signatures",
+ [
+ {
+ "user_id": user_id,
+ "key_id": item.signing_key_id,
+ "target_user_id": item.target_user_id,
+ "target_device_id": item.target_device_id,
+ "signature": item.signature,
+ }
+ for item in signatures
+ ],
+ "add_e2e_signing_key",
+ )
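
The signatures are written one row per `SignatureListItem`. A sketch of that mapping, using a namedtuple stand-in that only carries the fields referenced above:

```python
from collections import namedtuple

# Stand-in for the SignatureListItem referenced above; only the fields used here.
SignatureListItem = namedtuple(
    "SignatureListItem",
    ["signing_key_id", "target_user_id", "target_device_id", "signature"],
)

signatures = [
    SignatureListItem(
        signing_key_id="ed25519:base64+self+signing+public+key",
        target_user_id="@alice:example.org",
        target_device_id="DEVICE1",
        signature="base64+signature",
    )
]

rows = [
    {
        "user_id": "@alice:example.org",
        "key_id": item.signing_key_id,
        "target_user_id": item.target_user_id,
        "target_device_id": item.target_device_id,
        "signature": item.signature,
    }
    for item in signatures
]
```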
diff --git a/synapse/storage/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 09e39c2c28..62d4e9f599 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -12,22 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
-import random
+from typing import Dict, List, Optional, Set, Tuple
-from six.moves import range
from six.moves.queue import Empty, PriorityQueue
-from unpaddedbase64 import encode_base64
-
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.storage.signatures import SignatureWorkerStore
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
+from synapse.storage.database import Database
from synapse.util.caches.descriptors import cached
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
@@ -47,37 +47,55 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_ids, include_given=include_given
).addCallback(self.get_events_as_list)
- def get_auth_chain_ids(self, event_ids, include_given=False):
+ def get_auth_chain_ids(
+ self,
+ event_ids: List[str],
+ include_given: bool = False,
+ ignore_events: Optional[Set[str]] = None,
+ ):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
- event_ids (list): state events
- include_given (bool): include the given events in result
+ event_ids: state events
+ include_given: include the given events in result
+ ignore_events: Set of events to exclude from the returned auth
+ chain. This is useful if the caller will just discard the
+ given events anyway, and saves us from figuring out their auth
+ chains if not required.
Returns:
list of event_ids
"""
- return self.runInteraction(
- "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
+ return self.db.runInteraction(
+ "get_auth_chain_ids",
+ self._get_auth_chain_ids_txn,
+ event_ids,
+ include_given,
+ ignore_events,
)
- def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
+ if ignore_events is None:
+ ignore_events = set()
+
if include_given:
results = set(event_ids)
else:
results = set()
- base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
+ base_sql = "SELECT auth_id FROM event_auth WHERE "
front = set(event_ids)
while front:
new_front = set()
- front_list = list(front)
- chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
- for chunk in chunks:
- txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk)
- new_front.update([r[0] for r in txn])
+ for chunk in batch_iter(front, 100):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", chunk
+ )
+ txn.execute(base_sql + clause, args)
+ new_front.update(r[0] for r in txn)
+ new_front -= ignore_events
new_front -= results
front = new_front
@@ -85,13 +103,161 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
+ def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+ """Given sets of state events figure out the auth chain difference (as
+ per state res v2 algorithm).
+
+ This is equivalent to fetching the full auth chain for each set of state
+ and returning the events that don't appear in each and every auth
+ chain.
+
+ Returns:
+ Deferred[Set[str]]
+ """
+
+ return self.db.runInteraction(
+ "get_auth_chain_difference",
+ self._get_auth_chain_difference_txn,
+ state_sets,
+ )
+
+ def _get_auth_chain_difference_txn(
+ self, txn, state_sets: List[Set[str]]
+ ) -> Set[str]:
+
+ # Algorithm Description
+ # ~~~~~~~~~~~~~~~~~~~~~
+ #
+ # The idea here is to basically walk the auth graph of each state set in
+ # tandem, keeping track of which auth events are reachable by each state
+ # set. If we reach an auth event we've already visited (via a different
+ # state set) then we mark that auth event and all ancestors as reachable
+ # by the state set. This requires that we keep track of the auth chains
+ # in memory.
+ #
+ # Doing it in such a way means that we can stop early if all auth
+ # events we're currently walking are reachable by all state sets.
+ #
+ # *Note*: We can't stop walking an event's auth chain if it is reachable
+ # by all state sets. This is because other auth chains we're walking
+ # might be reachable only via the original auth chain. For example,
+ # given the following auth chain:
+ #
+ # A -> C -> D -> E
+ # / /
+ # B -´---------´
+ #
+ # and state sets {A} and {B} then walking the auth chains of A and B
+ # would immediately show that C is reachable by both. However, if we
+ # stopped at C then we'd only reach E via the auth chain of B and so E
+ # would erroneously get included in the returned difference.
+ #
+ # The other thing that we do is limit the number of auth chains we walk
+ # at once, due to practical limits (i.e. we can only query the database
+ # with a limited set of parameters). We pick the auth chains we walk
+ # each iteration based on their depth, in the hope that events with a
+ # lower depth are likely reachable by those with higher depths.
+ #
+ # We could use any ordering that we believe would give a rough
+ # topological ordering, e.g. origin server timestamp. If the ordering
+ # chosen is not topological then the algorithm still produces the right
+ # result, but perhaps a bit more inefficiently. This is why it is safe
+ # to use "depth" here.
+
+ initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+ # Dict from events in auth chains to the state sets that *cannot* reach them.
+ # I.e. if the set is empty then all sets can reach the event.
+ event_to_missing_sets = {
+ event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
+ for event_id in initial_events
+ }
+
+ # We need to get the depth of the initial events for sorting purposes.
+ sql = """
+ SELECT depth, event_id FROM events
+ WHERE %s
+ ORDER BY depth ASC
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", initial_events
+ )
+ txn.execute(sql % (clause,), args)
+
+ # The sorted list of events whose auth chains we should walk.
+ search = txn.fetchall() # type: List[Tuple[int, str]]
+
+ # Map from event to its auth events
+ event_to_auth_events = {} # type: Dict[str, Set[str]]
+
+ base_sql = """
+ SELECT a.event_id, auth_id, depth
+ FROM event_auth AS a
+ INNER JOIN events AS e ON (e.event_id = a.auth_id)
+ WHERE
+ """
+
+ while search:
+ # Check whether all our current walks are reachable by all state
+ # sets. If so we can bail.
+ if all(not event_to_missing_sets[eid] for _, eid in search):
+ break
+
+ # Fetch the auth events (and their depths) of the last N events we're
+ # currently walking
+ search, chunk = search[:-100], search[-100:]
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
+ )
+ txn.execute(base_sql + clause, args)
+
+ for event_id, auth_event_id, auth_event_depth in txn:
+ event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
+
+ sets = event_to_missing_sets.get(auth_event_id)
+ if sets is None:
+ # First time we're seeing this event, so we add it to the
+ # queue of things to fetch.
+ search.append((auth_event_depth, auth_event_id))
+
+ # Assume that this event is unreachable from any of the
+ # state sets until proven otherwise
+ sets = event_to_missing_sets[auth_event_id] = set(
+ range(len(state_sets))
+ )
+ else:
+ # We've previously seen this event, so look up its auth
+ # events and recursively mark all ancestors as reachable
+ # by the current event's state set.
+ a_ids = event_to_auth_events.get(auth_event_id)
+ while a_ids:
+ new_aids = set()
+ for a_id in a_ids:
+ event_to_missing_sets[a_id].intersection_update(
+ event_to_missing_sets[event_id]
+ )
+
+ b = event_to_auth_events.get(a_id)
+ if b:
+ new_aids.update(b)
+
+ a_ids = new_aids
+
+ # Mark that the auth event is reachable by the appropriate sets.
+ sets.intersection_update(event_to_missing_sets[event_id])
+
+ search.sort()
+
+ # Return all events where not all sets can reach them.
+ return {eid for eid, n in event_to_missing_sets.items() if n}
+
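As a sanity check on what this transaction computes, here is a naive in-memory version of the definition in the docstring: fetch the full auth chain of each state set and return the events that are missing from at least one of them. The auth graph used is made up for the sketch, not the one from the comment above.

```python
def naive_auth_chain_difference(state_sets, auth_events):
    """Naive version of _get_auth_chain_difference_txn: build the full auth
    chain per state set, then return the events missing from at least one chain.

    auth_events maps event_id -> set of its auth event ids (an in-memory
    stand-in for the event_auth table).
    """
    def full_auth_chain(events):
        chain, frontier = set(), set(events)
        while frontier:
            frontier = {
                auth_id
                for event_id in frontier
                for auth_id in auth_events.get(event_id, ())
            } - chain
            chain |= frontier
        return chain

    chains = [full_auth_chain(s) for s in state_sets]
    union = set().union(*chains)
    intersection = set.intersection(*chains) if chains else set()
    return union - intersection

# A's auth chain is {C}; B's is {D, C}; only D is not in every chain.
auth_events = {"A": {"C"}, "B": {"D"}, "C": set(), "D": {"C"}}
assert naive_auth_chain_difference([{"A"}, {"B"}], auth_events) == {"D"}
```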
def get_oldest_events_in_room(self, room_id):
- return self.runInteraction(
+ return self.db.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
)
def get_oldest_events_with_depth_in_room(self, room_id):
- return self.runInteraction(
+ return self.db.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
@@ -122,7 +288,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns
Deferred[int]
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
@@ -131,20 +297,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
if not rows:
- defer.returnValue(0)
+ return 0
else:
- defer.returnValue(max(row["depth"] for row in rows))
+ return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
- return self._simple_select_onecol_txn(
+ return self.db.simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
)
- @defer.inlineCallbacks
- def get_prev_events_for_room(self, room_id):
+ def get_prev_events_for_room(self, room_id: str):
"""
Gets a subset of the current forward extremities in the given room.
@@ -155,80 +320,87 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
room_id (str): room_id
Returns:
- Deferred[list[(str, dict[str, str], int)]]
- for each event, a tuple of (event_id, hashes, depth)
- where *hashes* is a map from algorithm to hash.
+ Deferred[List[str]]: the event ids of the forward extremities
+
"""
- res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
- if len(res) > 10:
- # Sort by reverse depth, so we point to the most recent.
- res.sort(key=lambda a: -a[2])
- # we use half of the limit for the actual most recent events, and
- # the other half to randomly point to some of the older events, to
- # make sure that we don't completely ignore the older events.
- res = res[0:5] + random.sample(res[5:], 5)
+ return self.db.runInteraction(
+ "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
+ )
- defer.returnValue(res)
+ def _get_prev_events_for_room_txn(self, txn, room_id: str):
+ # we just use the 10 newest events. Older events will become
+ # prev_events of future events.
- def get_latest_event_ids_and_hashes_in_room(self, room_id):
+ sql = """
+ SELECT e.event_id FROM event_forward_extremities AS f
+ INNER JOIN events AS e USING (event_id)
+ WHERE f.room_id = ?
+ ORDER BY e.depth DESC
+ LIMIT 10
"""
- Gets the current forward extremities in the given room
+
+ txn.execute(sql, (room_id,))
+
+ return [row[0] for row in txn]
+
+ def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
+ """Get the top rooms with at least N extremities.
Args:
- room_id (str): room_id
+ min_count (int): The minimum number of extremities
+ limit (int): The maximum number of rooms to return.
+ room_id_filter (iterable[str]): room_ids to exclude from the results
Returns:
- Deferred[list[(str, dict[str, str], int)]]
- for each event, a tuple of (event_id, hashes, depth)
- where *hashes* is a map from algorithm to hash.
+ Deferred[list]: At most `limit` room IDs that have at least
+ `min_count` extremities, sorted by extremity count.
"""
- return self.runInteraction(
- "get_latest_event_ids_and_hashes_in_room",
- self._get_latest_event_ids_and_hashes_in_room,
- room_id,
+ def _get_rooms_with_many_extremities_txn(txn):
+ where_clause = "1=1"
+ if room_id_filter:
+ where_clause = "room_id NOT IN (%s)" % (
+ ",".join("?" for _ in room_id_filter),
+ )
+
+ sql = """
+ SELECT room_id FROM event_forward_extremities
+ WHERE %s
+ GROUP BY room_id
+ HAVING count(*) > ?
+ ORDER BY count(*) DESC
+ LIMIT ?
+ """ % (
+ where_clause,
+ )
+
+ query_args = list(itertools.chain(room_id_filter, [min_count, limit]))
+ txn.execute(sql, query_args)
+ return [room_id for room_id, in txn]
+
+ return self.db.runInteraction(
+ "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
)
@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
- return self._simple_select_onecol(
+ return self.db.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
desc="get_latest_event_ids_in_room",
)
- def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id):
- sql = (
- "SELECT e.event_id, e.depth FROM events as e "
- "INNER JOIN event_forward_extremities as f "
- "ON e.event_id = f.event_id "
- "AND e.room_id = f.room_id "
- "WHERE f.room_id = ?"
- )
-
- txn.execute(sql, (room_id,))
-
- results = []
- for event_id, depth in txn.fetchall():
- hashes = self._get_event_reference_hashes_txn(txn, event_id)
- prev_hashes = {
- k: encode_base64(v) for k, v in hashes.items() if k == "sha256"
- }
- results.append((event_id, prev_hashes, depth))
-
- return results
-
def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
def _get_min_depth_interaction(self, txn, room_id):
- min_depth = self._simple_select_one_onecol_txn(
+ min_depth = self.db.simple_select_one_onecol_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -294,7 +466,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
@@ -309,7 +481,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
limit (int)
"""
return (
- self.runInteraction(
+ self.db.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
@@ -321,9 +493,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
def _get_backfill_events(self, txn, room_id, event_list, limit):
- logger.debug(
- "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit
- )
+ logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
event_results = set()
@@ -342,7 +512,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
queue = PriorityQueue()
for event_id in event_list:
- depth = self._simple_select_one_onecol_txn(
+ depth = self.db.simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={"event_id": event_id, "room_id": room_id},
@@ -374,7 +544,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
@defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events, limit):
- ids = yield self.runInteraction(
+ ids = yield self.db.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id,
@@ -383,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
limit,
)
events = yield self.get_events_as_list(ids)
- defer.returnValue(events)
+ return events
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@@ -404,7 +574,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
query, (room_id, event_id, False, limit - len(event_results))
)
- new_results = set(t[0] for t in txn) - seen_events
+ new_results = {t[0] for t in txn} - seen_events
new_front |= new_results
seen_events |= new_results
@@ -427,7 +597,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
Deferred[list[str]]
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
@@ -435,7 +605,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
desc="get_successor_events",
)
- defer.returnValue([row["event_id"] for row in rows])
+ return [row["event_id"] for row in rows]
class EventFederationStore(EventFederationWorkerStore):
@@ -450,10 +620,10 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
- def __init__(self, db_conn, hs):
- super(EventFederationStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventFederationStore, self).__init__(database, db_conn, hs)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
@@ -464,10 +634,10 @@ class EventFederationStore(EventFederationWorkerStore):
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
min_depth = self._get_min_depth_interaction(txn, room_id)
- if min_depth and depth >= min_depth:
+ if min_depth is not None and depth >= min_depth:
return
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -479,7 +649,7 @@ class EventFederationStore(EventFederationWorkerStore):
For the given event, update the event edges table and forward and
backward extremities tables.
"""
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_edges",
values=[
@@ -563,13 +733,13 @@ class EventFederationStore(EventFederationWorkerStore):
return run_as_background_process(
"delete_old_forward_extrem_cache",
- self.runInteraction,
+ self.db.runInteraction,
"_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn,
)
def clean_room_for_join(self, room_id):
- return self.runInteraction(
+ return self.db.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
@@ -613,17 +783,17 @@ class EventFederationStore(EventFederationWorkerStore):
"max_stream_id_exclusive": min_stream_id,
}
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, self.EVENT_AUTH_STATE_ONLY, new_progress
)
return min_stream_id >= target_min_stream_id
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
)
if not result:
- yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+ yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY)
- defer.returnValue(batch_size)
+ return batch_size
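Throughout this file the transaction plumbing moves from methods inherited by the store (`self.runInteraction`, `self._simple_*`) to a database wrapper held at `self.db`. A minimal, synchronous stand-in of that pattern, written purely for illustration (the `Database` shim and `_count_extremities_txn` helper are invented; the real wrapper returns Deferreds and runs the supplied function on a database thread):

    import sqlite3

    class Database:
        """Toy stand-in for the real database wrapper (illustrative only)."""

        def __init__(self, conn):
            self._conn = conn

        def runInteraction(self, desc, func, *args, **kwargs):
            # Run `func(txn, ...)` inside a transaction; `desc` is only used for
            # logging in the real implementation, so it is ignored here.
            txn = self._conn.cursor()
            try:
                result = func(txn, *args, **kwargs)
                self._conn.commit()
                return result
            except Exception:
                self._conn.rollback()
                raise

    def _count_extremities_txn(txn, room_id):
        txn.execute(
            "SELECT count(*) FROM event_forward_extremities WHERE room_id = ?",
            (room_id,),
        )
        return txn.fetchone()[0]

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE event_forward_extremities (event_id TEXT, room_id TEXT)")
    db = Database(conn)
    print(db.runInteraction("count_extremities", _count_extremities_txn, "!a:example.org"))  # 0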
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index a729f3e067..8eed590929 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.storage.database import Database
from synapse.util.caches.descriptors import cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
self.stream_ordering_month_ago = None
@@ -79,8 +80,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine,
- after_callbacks=[],
- exception_callbacks=[],
)
self._find_stream_orderings_for_times_txn(cur)
cur.close()
@@ -95,14 +94,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
- ret = yield self.runInteraction(
+ ret = yield self.db.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
- defer.returnValue(ret)
+ return ret
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
@@ -179,8 +178,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
- ret = yield self.runInteraction("get_push_action_users_in_range", f)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
+ return ret
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_http(
@@ -231,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = yield self.runInteraction(
+ after_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
@@ -259,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = yield self.runInteraction(
+ no_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
@@ -277,11 +276,11 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by stream_ordering, oldest first.
- notifs.sort(key=lambda r: r['stream_ordering'])
+ notifs.sort(key=lambda r: r["stream_ordering"])
# Take only up to the limit. We have to stop at the limit because
# one of the subqueries may have hit the limit.
- defer.returnValue(notifs[:limit])
+ return notifs[:limit]
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_email(
@@ -331,7 +330,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = yield self.runInteraction(
+ after_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
@@ -359,7 +358,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = yield self.runInteraction(
+ no_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
@@ -379,10 +378,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by received_ts (most recent first)
- notifs.sort(key=lambda r: -(r['received_ts'] or 0))
+ notifs.sort(key=lambda r: -(r["received_ts"] or 0))
# Now return the first `limit`
- defer.returnValue(notifs[:limit])
+ return notifs[:limit]
def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
"""A fast check to see if there might be something to push for the
@@ -409,7 +408,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
- return self.runInteraction(
+ return self.db.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
@@ -443,7 +442,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _add_push_actions_to_staging_txn(txn):
- # We don't use _simple_insert_many here to avoid the overhead
+ # We don't use simple_insert_many here to avoid the overhead
# of generating lists of dicts.
sql = """
@@ -460,7 +459,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
),
)
- return self.runInteraction(
+ return self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
@@ -474,12 +473,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
try:
- res = yield self._simple_delete(
+ res = yield self.db.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
)
- defer.returnValue(res)
+ return res
except Exception:
# this method is called from an exception handler, so propagating
# another exception here really isn't helpful - there's nothing
@@ -491,7 +490,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _find_stream_orderings_for_times(self):
return run_as_background_process(
"event_push_action_stream_orderings",
- self.runInteraction,
+ self.db.runInteraction,
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn,
)
@@ -527,7 +526,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
Deferred[int]: stream ordering of the first event received on/after
the timestamp
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
@@ -609,21 +608,38 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end
+ @defer.inlineCallbacks
+ def get_time_of_last_push_action_before(self, stream_ordering):
+ def f(txn):
+ sql = (
+ "SELECT e.received_ts"
+ " FROM event_push_actions AS ep"
+ " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
+ " WHERE ep.stream_ordering > ?"
+ " ORDER BY ep.stream_ordering ASC"
+ " LIMIT 1"
+ )
+ txn.execute(sql, (stream_ordering,))
+ return txn.fetchone()
+
+ result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
+ return result[0] if result else None
+
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
- def __init__(self, db_conn, hs):
- super(EventPushActionsStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventPushActionsStore, self).__init__(database, db_conn, hs)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,
index_name="event_push_actions_u_highlight",
table="event_push_actions",
columns=["user_id", "stream_ordering"],
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"event_push_actions_highlights_index",
index_name="event_push_actions_highlights_index",
table="event_push_actions",
@@ -679,7 +695,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
for event, _ in events_and_contexts:
- user_ids = self._simple_select_onecol_txn(
+ user_ids = self.db.simple_select_onecol_txn(
txn,
table="event_push_actions_staging",
keyvalues={"event_id": event.event_id},
@@ -729,29 +745,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- push_actions = yield self.runInteraction("get_push_actions_for_user", f)
+ push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
- defer.returnValue(push_actions)
-
- @defer.inlineCallbacks
- def get_time_of_last_push_action_before(self, stream_ordering):
- def f(txn):
- sql = (
- "SELECT e.received_ts"
- " FROM event_push_actions AS ep"
- " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
- " WHERE ep.stream_ordering > ?"
- " ORDER BY ep.stream_ordering ASC"
- " LIMIT 1"
- )
- txn.execute(sql, (stream_ordering,))
- return txn.fetchone()
-
- result = yield self.runInteraction("get_time_of_last_push_action_before", f)
- defer.returnValue(result[0] if result else None)
+ return push_actions
@defer.inlineCallbacks
def get_latest_push_action_stream_ordering(self):
@@ -759,8 +758,10 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
- result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
- defer.returnValue(result[0] or 0)
+ result = yield self.db.runInteraction(
+ "get_latest_push_action_stream_ordering", f
+ )
+ return result[0] or 0
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
# Sad that we have to blow away the cache for the whole room here
@@ -832,7 +833,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True:
logger.info("Rotating notifications")
- caught_up = yield self.runInteraction(
+ caught_up = yield self.db.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn
)
if caught_up:
@@ -846,7 +847,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
the archiving process has caught up or not.
"""
- old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -865,7 +866,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
stream_row = txn.fetchone()
if stream_row:
- offset_stream_ordering, = stream_row
+ (offset_stream_ordering,) = stream_row
rotate_to_stream_ordering = min(
self.stream_ordering_day_ago, offset_stream_ordering
)
@@ -882,7 +883,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return caught_up
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
- old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -914,7 +915,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
# existing table.
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_push_summary",
values=[
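The recurring `defer.returnValue(x)` to `return x` change in this file relies on Python 3 generator semantics: `return x` inside a generator raises `StopIteration(x)`, which `inlineCallbacks` uses to fire its Deferred, so the two spellings behave identically. A small self-contained check (requires Twisted; the function names are made up):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def old_style():
        yield defer.succeed(None)
        defer.returnValue(42)

    @defer.inlineCallbacks
    def new_style():
        yield defer.succeed(None)
        return 42

    results = []
    old_style().addCallback(results.append)
    new_style().addCallback(results.append)
    assert results == [42, 42]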
diff --git a/synapse/storage/events.py b/synapse/storage/data_stores/main/events.py
index bc3e6de3bf..d593ef47b8 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,8 +17,9 @@
import itertools
import logging
-from collections import OrderedDict, deque, namedtuple
+from collections import Counter as c_counter, OrderedDict, namedtuple
from functools import wraps
+from typing import Dict, List, Tuple
from six import iteritems, text_type
from six.moves import range
@@ -29,24 +30,25 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.errors import SynapseError
+from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.events.utils import prune_event_dict
+from synapse.logging.utils import log_function
+from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
-from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.storage.event_federation import EventFederationStore
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.storage.state import StateGroupWorkerStore
-from synapse.types import RoomStreamToken, get_domain_from_id
-from synapse.util import batch_iter
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.storage._base import make_in_list_sql_clause
+from synapse.storage.data_stores.main.event_federation import EventFederationStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.persist_events import DeltaState
+from synapse.types import RoomStreamToken, StateMap, get_domain_from_id
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
-from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
-from synapse.util.logutils import log_function
-from synapse.util.metrics import Measure
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
@@ -57,22 +59,6 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"],
)
-# The number of times we are recalculating the current state
-state_delta_counter = Counter("synapse_storage_events_state_delta", "")
-
-# The number of times we are recalculating state when there is only a
-# single forward extremity
-state_delta_single_event_counter = Counter(
- "synapse_storage_events_state_delta_single_event", ""
-)
-
-# The number of times we are reculating state when we could have resonably
-# calculated the delta when we calculated the state for an event we were
-# persisting.
-state_delta_reuse_delta_counter = Counter(
- "synapse_storage_events_state_delta_reuse_delta", ""
-)
-
def encode_json(json_object):
"""
@@ -84,110 +70,6 @@ def encode_json(json_object):
return out
-class _EventPeristenceQueue(object):
- """Queues up events so that they can be persisted in bulk with only one
- concurrent transaction per room.
- """
-
- _EventPersistQueueItem = namedtuple(
- "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
- )
-
- def __init__(self):
- self._event_persist_queues = {}
- self._currently_persisting_rooms = set()
-
- def add_to_queue(self, room_id, events_and_contexts, backfilled):
- """Add events to the queue, with the given persist_event options.
-
- NB: due to the normal usage pattern of this method, it does *not*
- follow the synapse logcontext rules, and leaves the logcontext in
- place whether or not the returned deferred is ready.
-
- Args:
- room_id (str):
- events_and_contexts (list[(EventBase, EventContext)]):
- backfilled (bool):
-
- Returns:
- defer.Deferred: a deferred which will resolve once the events are
- persisted. Runs its callbacks *without* a logcontext.
- """
- queue = self._event_persist_queues.setdefault(room_id, deque())
- if queue:
- # if the last item in the queue has the same `backfilled` setting,
- # we can just add these new events to that item.
- end_item = queue[-1]
- if end_item.backfilled == backfilled:
- end_item.events_and_contexts.extend(events_and_contexts)
- return end_item.deferred.observe()
-
- deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
-
- queue.append(
- self._EventPersistQueueItem(
- events_and_contexts=events_and_contexts,
- backfilled=backfilled,
- deferred=deferred,
- )
- )
-
- return deferred.observe()
-
- def handle_queue(self, room_id, per_item_callback):
- """Attempts to handle the queue for a room if not already being handled.
-
- The given callback will be invoked with for each item in the queue,
- of type _EventPersistQueueItem. The per_item_callback will continuously
- be called with new items, unless the queue becomnes empty. The return
- value of the function will be given to the deferreds waiting on the item,
- exceptions will be passed to the deferreds as well.
-
- This function should therefore be called whenever anything is added
- to the queue.
-
- If another callback is currently handling the queue then it will not be
- invoked.
- """
-
- if room_id in self._currently_persisting_rooms:
- return
-
- self._currently_persisting_rooms.add(room_id)
-
- @defer.inlineCallbacks
- def handle_queue_loop():
- try:
- queue = self._get_drainining_queue(room_id)
- for item in queue:
- try:
- ret = yield per_item_callback(item)
- except Exception:
- with PreserveLoggingContext():
- item.deferred.errback()
- else:
- with PreserveLoggingContext():
- item.deferred.callback(ret)
- finally:
- queue = self._event_persist_queues.pop(room_id, None)
- if queue:
- self._event_persist_queues[room_id] = queue
- self._currently_persisting_rooms.discard(room_id)
-
- # set handle_queue_loop off in the background
- run_as_background_process("persist_events", handle_queue_loop)
-
- def _get_drainining_queue(self, room_id):
- queue = self._event_persist_queues.setdefault(room_id, deque())
-
- try:
- while True:
- yield queue.popleft()
- except IndexError:
- # Queue has been drained.
- pass
-
-
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@@ -203,11 +85,11 @@ def _retry_on_integrity_error(func):
@defer.inlineCallbacks
def f(self, *args, **kwargs):
try:
- res = yield func(self, *args, **kwargs)
+ res = yield func(self, *args, delete_existing=False, **kwargs)
except self.database_engine.module.IntegrityError:
logger.exception("IntegrityError, retrying.")
res = yield func(self, *args, delete_existing=True, **kwargs)
- defer.returnValue(res)
+ return res
return f
@@ -215,106 +97,101 @@ def _retry_on_integrity_error(func):
# inherits from EventFederationStore so that we can call _update_backward_extremities
# and _handle_mult_prev_events (though arguably those could both be moved in here)
class EventsStore(
- StateGroupWorkerStore,
- EventFederationStore,
- EventsWorkerStore,
- BackgroundUpdateStore,
+ StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
):
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventsStore, self).__init__(database, db_conn, hs)
- def __init__(self, db_conn, hs):
- super(EventsStore, self).__init__(db_conn, hs)
-
- self._event_persist_queue = _EventPeristenceQueue()
- self._state_resolution_handler = hs.get_state_resolution_handler()
+ # Collect metrics on the number of forward extremities that exist.
+ # Counter of number of extremities to count
+ self._current_forward_extremities_amount = c_counter()
- @defer.inlineCallbacks
- def persist_events(self, events_and_contexts, backfilled=False):
- """
- Write events to the database
- Args:
- events_and_contexts: list of tuples of (event, context)
- backfilled (bool): Whether the results are retrieved from federation
- via backfill or not. Used to determine if they're "new" events
- which might update the current state etc.
+ BucketCollector(
+ "synapse_forward_extremities",
+ lambda: self._current_forward_extremities_amount,
+ buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
+ )
- Returns:
- Deferred[int]: the stream ordering of the latest persisted event
- """
- partitioned = {}
- for event, ctx in events_and_contexts:
- partitioned.setdefault(event.room_id, []).append((event, ctx))
-
- deferreds = []
- for room_id, evs_ctxs in iteritems(partitioned):
- d = self._event_persist_queue.add_to_queue(
- room_id, evs_ctxs, backfilled=backfilled
+ # Read the extrems every 60 minutes
+ def read_forward_extremities():
+ # run as a background process to make sure that the database transactions
+ # have a logcontext to report to
+ return run_as_background_process(
+ "read_forward_extremities", self._read_forward_extremities
)
- deferreds.append(d)
- for room_id in partitioned:
- self._maybe_start_persisting(room_id)
+ hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
- yield make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True)
- )
+ def _censor_redactions():
+ return run_as_background_process(
+ "_censor_redactions", self._censor_redactions
+ )
- max_persisted_id = yield self._stream_id_gen.get_current_token()
+ if self.hs.config.redaction_retention_period is not None:
+ hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
- defer.returnValue(max_persisted_id)
+ self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+ self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks
- @log_function
- def persist_event(self, event, context, backfilled=False):
- """
-
- Args:
- event (EventBase):
- context (EventContext):
- backfilled (bool):
-
- Returns:
- Deferred: resolves to (int, int): the stream ordering of ``event``,
- and the stream ordering of the latest persisted event
- """
- deferred = self._event_persist_queue.add_to_queue(
- event.room_id, [(event, context)], backfilled=backfilled
- )
-
- self._maybe_start_persisting(event.room_id)
-
- yield make_deferred_yieldable(deferred)
-
- max_persisted_id = yield self._stream_id_gen.get_current_token()
- defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
-
- def _maybe_start_persisting(self, room_id):
- @defer.inlineCallbacks
- def persisting_queue(item):
- with Measure(self._clock, "persist_events"):
- yield self._persist_events(
- item.events_and_contexts, backfilled=item.backfilled
- )
+ def _read_forward_extremities(self):
+ def fetch(txn):
+ txn.execute(
+ """
+ select count(*) c from event_forward_extremities
+ group by room_id
+ """
+ )
+ return txn.fetchall()
- self._event_persist_queue.handle_queue(room_id, persisting_queue)
+ res = yield self.db.runInteraction("read_forward_extremities", fetch)
+ self._current_forward_extremities_amount = c_counter([x[0] for x in res])
@_retry_on_integrity_error
@defer.inlineCallbacks
- def _persist_events(
- self, events_and_contexts, backfilled=False, delete_existing=False
+ def _persist_events_and_state_updates(
+ self,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ current_state_for_room: Dict[str, StateMap[str]],
+ state_delta_for_room: Dict[str, DeltaState],
+ new_forward_extremeties: Dict[str, List[str]],
+ backfilled: bool = False,
+ delete_existing: bool = False,
):
- """Persist events to db
+ """Persist a set of events alongside updates to the current state and
+ forward extremities tables.
Args:
- events_and_contexts (list[(EventBase, EventContext)]):
- backfilled (bool):
- delete_existing (bool):
+ events_and_contexts:
+ current_state_for_room: Map from room_id to the current state of
+ the room based on forward extremities
+ state_delta_for_room: Map from room_id to the delta to apply to
+ room state
+ new_forward_extremities: Map from room_id to list of event IDs
+ that are the new forward extremities of the room.
+ backfilled
+ delete_existing
Returns:
Deferred: resolves when the events have been persisted
"""
- if not events_and_contexts:
- return
+ # We want to calculate the stream orderings as late as possible, as
+ # we only notify after all events with a lesser stream ordering have
+ # been persisted. I.e. if we spend 10s inside the with block then
+ # that will delay all subsequent events from being notified about.
+ # Hence why we do it down here rather than wrapping the entire
+ # function.
+ #
+        # It's safe to do this after calculating the state deltas etc as we
+ # only need to protect the *persistence* of the events. This is to
+ # ensure that queries of the form "fetch events since X" don't
+ # return events and stream positions after events that are still in
+ # flight, as otherwise subsequent requests "fetch event since Y"
+ # will not return those events.
+ #
+ # Note: Multiple instances of this function cannot be in flight at
+ # the same time for the same room.
if backfilled:
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
@@ -328,216 +205,44 @@ class EventsStore(
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
- chunks = [
- events_and_contexts[x : x + 100]
- for x in range(0, len(events_and_contexts), 100)
- ]
-
- for chunk in chunks:
- # We can't easily parallelize these since different chunks
- # might contain the same event. :(
-
- # NB: Assumes that we are only persisting events for one room
- # at a time.
-
- # map room_id->list[event_ids] giving the new forward
- # extremities in each room
- new_forward_extremeties = {}
-
- # map room_id->(type,state_key)->event_id tracking the full
- # state in each room after adding these events.
- # This is simply used to prefill the get_current_state_ids
- # cache
- current_state_for_room = {}
-
- # map room_id->(to_delete, to_insert) where to_delete is a list
- # of type/state keys to remove from current state, and to_insert
- # is a map (type,key)->event_id giving the state delta in each
- # room
- state_delta_for_room = {}
-
- if not backfilled:
- with Measure(self._clock, "_calculate_state_and_extrem"):
- # Work out the new "current state" for each room.
- # We do this by working out what the new extremities are and then
- # calculating the state from that.
- events_by_room = {}
- for event, context in chunk:
- events_by_room.setdefault(event.room_id, []).append(
- (event, context)
- )
-
- for room_id, ev_ctx_rm in iteritems(events_by_room):
- latest_event_ids = yield self.get_latest_event_ids_in_room(
- room_id
- )
- new_latest_event_ids = yield self._calculate_new_extremities(
- room_id, ev_ctx_rm, latest_event_ids
- )
-
- latest_event_ids = set(latest_event_ids)
- if new_latest_event_ids == latest_event_ids:
- # No change in extremities, so no change in state
- continue
-
- # there should always be at least one forward extremity.
- # (except during the initial persistence of the send_join
- # results, in which case there will be no existing
- # extremities, so we'll `continue` above and skip this bit.)
- assert new_latest_event_ids, "No forward extremities left!"
-
- new_forward_extremeties[room_id] = new_latest_event_ids
-
- len_1 = (
- len(latest_event_ids) == 1
- and len(new_latest_event_ids) == 1
- )
- if len_1:
- all_single_prev_not_state = all(
- len(event.prev_event_ids()) == 1
- and not event.is_state()
- for event, ctx in ev_ctx_rm
- )
- # Don't bother calculating state if they're just
- # a long chain of single ancestor non-state events.
- if all_single_prev_not_state:
- continue
-
- state_delta_counter.inc()
- if len(new_latest_event_ids) == 1:
- state_delta_single_event_counter.inc()
-
- # This is a fairly handwavey check to see if we could
- # have guessed what the delta would have been when
- # processing one of these events.
- # What we're interested in is if the latest extremities
- # were the same when we created the event as they are
- # now. When this server creates a new event (as opposed
- # to receiving it over federation) it will use the
- # forward extremities as the prev_events, so we can
- # guess this by looking at the prev_events and checking
- # if they match the current forward extremities.
- for ev, _ in ev_ctx_rm:
- prev_event_ids = set(ev.prev_event_ids())
- if latest_event_ids == prev_event_ids:
- state_delta_reuse_delta_counter.inc()
- break
-
- logger.info("Calculating state delta for room %s", room_id)
- with Measure(
- self._clock, "persist_events.get_new_state_after_events"
- ):
- res = yield self._get_new_state_after_events(
- room_id,
- ev_ctx_rm,
- latest_event_ids,
- new_latest_event_ids,
- )
- current_state, delta_ids = res
-
- # If either are not None then there has been a change,
- # and we need to work out the delta (or use that
- # given)
- if delta_ids is not None:
- # If there is a delta we know that we've
- # only added or replaced state, never
- # removed keys entirely.
- state_delta_for_room[room_id] = ([], delta_ids)
- elif current_state is not None:
- with Measure(
- self._clock, "persist_events.calculate_state_delta"
- ):
- delta = yield self._calculate_state_delta(
- room_id, current_state
- )
- state_delta_for_room[room_id] = delta
-
- # If we have the current_state then lets prefill
- # the cache with it.
- if current_state is not None:
- current_state_for_room[room_id] = current_state
-
- yield self.runInteraction(
- "persist_events",
- self._persist_events_txn,
- events_and_contexts=chunk,
- backfilled=backfilled,
- delete_existing=delete_existing,
- state_delta_for_room=state_delta_for_room,
- new_forward_extremeties=new_forward_extremeties,
- )
- persist_event_counter.inc(len(chunk))
-
- if not backfilled:
- # backfilled events have negative stream orderings, so we don't
- # want to set the event_persisted_position to that.
- synapse.metrics.event_persisted_position.set(
- chunk[-1][0].internal_metadata.stream_ordering
- )
-
- for event, context in chunk:
- if context.app_service:
- origin_type = "local"
- origin_entity = context.app_service.id
- elif self.hs.is_mine_id(event.sender):
- origin_type = "local"
- origin_entity = "*client*"
- else:
- origin_type = "remote"
- origin_entity = get_domain_from_id(event.sender)
-
- event_counter.labels(event.type, origin_type, origin_entity).inc()
-
- for room_id, new_state in iteritems(current_state_for_room):
- self.get_current_state_ids.prefill((room_id,), new_state)
-
- for room_id, latest_event_ids in iteritems(new_forward_extremeties):
- self.get_latest_event_ids_in_room.prefill(
- (room_id,), list(latest_event_ids)
- )
-
- @defer.inlineCallbacks
- def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
- """Calculates the new forward extremities for a room given events to
- persist.
-
- Assumes that we are only persisting events for one room at a time.
- """
-
- # we're only interested in new events which aren't outliers and which aren't
- # being rejected.
- new_events = [
- event
- for event, ctx in event_contexts
- if not event.internal_metadata.is_outlier()
- and not ctx.rejected
- and not event.internal_metadata.is_soft_failed()
- ]
-
- # start with the existing forward extremities
- result = set(latest_event_ids)
+ yield self.db.runInteraction(
+ "persist_events",
+ self._persist_events_txn,
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ delete_existing=delete_existing,
+ state_delta_for_room=state_delta_for_room,
+ new_forward_extremeties=new_forward_extremeties,
+ )
+ persist_event_counter.inc(len(events_and_contexts))
- # add all the new events to the list
- result.update(event.event_id for event in new_events)
+ if not backfilled:
+ # backfilled events have negative stream orderings, so we don't
+ # want to set the event_persisted_position to that.
+ synapse.metrics.event_persisted_position.set(
+ events_and_contexts[-1][0].internal_metadata.stream_ordering
+ )
- # Now remove all events which are prev_events of any of the new events
- result.difference_update(
- e_id for event in new_events for e_id in event.prev_event_ids()
- )
+ for event, context in events_and_contexts:
+ if context.app_service:
+ origin_type = "local"
+ origin_entity = context.app_service.id
+ elif self.hs.is_mine_id(event.sender):
+ origin_type = "local"
+ origin_entity = "*client*"
+ else:
+ origin_type = "remote"
+ origin_entity = get_domain_from_id(event.sender)
- # Remove any events which are prev_events of any existing events.
- existing_prevs = yield self._get_events_which_are_prevs(result)
- result.difference_update(existing_prevs)
+ event_counter.labels(event.type, origin_type, origin_entity).inc()
- # Finally handle the case where the new events have soft-failed prev
- # events. If they do we need to remove them and their prev events,
- # otherwise we end up with dangling extremities.
- existing_prevs = yield self._get_prevs_before_rejected(
- e_id for event in new_events for e_id in event.prev_event_ids()
- )
- result.difference_update(existing_prevs)
+ for room_id, new_state in iteritems(current_state_for_room):
+ self.get_current_state_ids.prefill((room_id,), new_state)
- defer.returnValue(result)
+ for room_id, latest_event_ids in iteritems(new_forward_extremeties):
+ self.get_latest_event_ids_in_room.prefill(
+ (room_id,), list(latest_event_ids)
+ )
@defer.inlineCallbacks
def _get_events_which_are_prevs(self, event_ids):
@@ -560,28 +265,24 @@ class EventsStore(
LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id)
WHERE
- prev_event_id IN (%s)
- AND NOT events.outlier
+ NOT events.outlier
AND rejections.event_id IS NULL
- """ % (
- ",".join("?" for _ in batch),
- )
+ AND
+ """
- txn.execute(sql, batch)
- results.extend(
- r[0]
- for r in txn
- if not json.loads(r[1]).get("soft_failed")
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "prev_event_id", batch
)
+ txn.execute(sql + clause, args)
+ results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
+
for chunk in batch_iter(event_ids, 100):
- yield self.runInteraction(
- "_get_events_which_are_prevs",
- _get_events_which_are_prevs_txn,
- chunk,
+ yield self.db.runInteraction(
+ "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
- defer.returnValue(results)
+ return results
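The `make_in_list_sql_clause(self.database_engine, column, iterable)` call above returns a `(clause, args)` pair that is appended after the trailing `AND` of the SQL. A simplified sketch of what such a helper produces (the real one also takes the database engine so it can pick an engine-specific form; this standalone version only builds the portable placeholder list):

    from typing import Any, Iterable, List, Tuple

    def make_in_list_sql_clause_sketch(column: str, iterable: Iterable[Any]) -> Tuple[str, List[Any]]:
        """Build "column IN (?, ?, ...)" plus its argument list (illustrative only)."""
        values = list(iterable)
        if not values:
            # "IN ()" is not valid SQL, so return a clause that matches nothing.
            return "1 = 0", []
        placeholders = ",".join("?" for _ in values)
        return "%s IN (%s)" % (column, placeholders), values

    clause, args = make_in_list_sql_clause_sketch("prev_event_id", ["$a", "$b"])
    assert clause == "prev_event_id IN (?,?)" and args == ["$a", "$b"]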
@defer.inlineCallbacks
def _get_prevs_before_rejected(self, event_ids):
@@ -620,13 +321,15 @@ class EventsStore(
LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id)
WHERE
- event_id IN (%s)
- AND NOT events.outlier
- """ % (
- ",".join("?" for _ in to_recursively_check),
+ NOT events.outlier
+ AND
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "event_id", to_recursively_check
)
- txn.execute(sql, to_recursively_check)
+ txn.execute(sql + clause, args)
to_recursively_check = []
for event_id, prev_event_id, metadata, rejected in txn:
@@ -639,205 +342,21 @@ class EventsStore(
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
- yield self.runInteraction(
- "_get_prevs_before_rejected",
- _get_prevs_before_rejected_txn,
- chunk,
+ yield self.db.runInteraction(
+ "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
- defer.returnValue(existing_prevs)
-
- @defer.inlineCallbacks
- def _get_new_state_after_events(
- self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
- ):
- """Calculate the current state dict after adding some new events to
- a room
-
- Args:
- room_id (str):
- room to which the events are being added. Used for logging etc
-
- events_context (list[(EventBase, EventContext)]):
- events and contexts which are being added to the room
-
- old_latest_event_ids (iterable[str]):
- the old forward extremities for the room.
-
- new_latest_event_ids (iterable[str]):
- the new forward extremities for the room.
-
- Returns:
- Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
- Returns a tuple of two state maps, the first being the full new current
- state and the second being the delta to the existing current state.
- If both are None then there has been no change.
-
- If there has been a change then we only return the delta if its
- already been calculated. Conversely if we do know the delta then
- the new current state is only returned if we've already calculated
- it.
- """
- # map from state_group to ((type, key) -> event_id) state map
- state_groups_map = {}
-
- # Map from (prev state group, new state group) -> delta state dict
- state_group_deltas = {}
-
- for ev, ctx in events_context:
- if ctx.state_group is None:
- # This should only happen for outlier events.
- if not ev.internal_metadata.is_outlier():
- raise Exception(
- "Context for new event %s has no state "
- "group" % (ev.event_id,)
- )
- continue
-
- if ctx.state_group in state_groups_map:
- continue
-
- # We're only interested in pulling out state that has already
- # been cached in the context. We'll pull stuff out of the DB later
- # if necessary.
- current_state_ids = ctx.get_cached_current_state_ids()
- if current_state_ids is not None:
- state_groups_map[ctx.state_group] = current_state_ids
-
- if ctx.prev_group:
- state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
-
- # We need to map the event_ids to their state groups. First, let's
- # check if the event is one we're persisting, in which case we can
- # pull the state group from its context.
- # Otherwise we need to pull the state group from the database.
-
- # Set of events we need to fetch groups for. (We know none of the old
- # extremities are going to be in events_context).
- missing_event_ids = set(old_latest_event_ids)
-
- event_id_to_state_group = {}
- for event_id in new_latest_event_ids:
- # First search in the list of new events we're adding.
- for ev, ctx in events_context:
- if event_id == ev.event_id and ctx.state_group is not None:
- event_id_to_state_group[event_id] = ctx.state_group
- break
- else:
- # If we couldn't find it, then we'll need to pull
- # the state from the database
- missing_event_ids.add(event_id)
-
- if missing_event_ids:
- # Now pull out the state groups for any missing events from DB
- event_to_groups = yield self._get_state_group_for_events(missing_event_ids)
- event_id_to_state_group.update(event_to_groups)
-
- # State groups of old_latest_event_ids
- old_state_groups = set(
- event_id_to_state_group[evid] for evid in old_latest_event_ids
- )
-
- # State groups of new_latest_event_ids
- new_state_groups = set(
- event_id_to_state_group[evid] for evid in new_latest_event_ids
- )
-
- # If they old and new groups are the same then we don't need to do
- # anything.
- if old_state_groups == new_state_groups:
- defer.returnValue((None, None))
-
- if len(new_state_groups) == 1 and len(old_state_groups) == 1:
- # If we're going from one state group to another, lets check if
- # we have a delta for that transition. If we do then we can just
- # return that.
-
- new_state_group = next(iter(new_state_groups))
- old_state_group = next(iter(old_state_groups))
-
- delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
- if delta_ids is not None:
- # We have a delta from the existing to new current state,
- # so lets just return that. If we happen to already have
- # the current state in memory then lets also return that,
- # but it doesn't matter if we don't.
- new_state = state_groups_map.get(new_state_group)
- defer.returnValue((new_state, delta_ids))
-
- # Now that we have calculated new_state_groups we need to get
- # their state IDs so we can resolve to a single state set.
- missing_state = new_state_groups - set(state_groups_map)
- if missing_state:
- group_to_state = yield self._get_state_for_groups(missing_state)
- state_groups_map.update(group_to_state)
-
- if len(new_state_groups) == 1:
- # If there is only one state group, then we know what the current
- # state is.
- defer.returnValue((state_groups_map[new_state_groups.pop()], None))
-
- # Ok, we need to defer to the state handler to resolve our state sets.
-
- state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
-
- events_map = {ev.event_id: ev for ev, _ in events_context}
-
- # We need to get the room version, which is in the create event.
- # Normally that'd be in the database, but its also possible that we're
- # currently trying to persist it.
- room_version = None
- for ev, _ in events_context:
- if ev.type == EventTypes.Create and ev.state_key == "":
- room_version = ev.content.get("room_version", "1")
- break
-
- if not room_version:
- room_version = yield self.get_room_version(room_id)
-
- logger.debug("calling resolve_state_groups from preserve_events")
- res = yield self._state_resolution_handler.resolve_state_groups(
- room_id,
- room_version,
- state_groups,
- events_map,
- state_res_store=StateResolutionStore(self),
- )
-
- defer.returnValue((res.state, None))
-
- @defer.inlineCallbacks
- def _calculate_state_delta(self, room_id, current_state):
- """Calculate the new state deltas for a room.
-
- Assumes that we are only persisting events for one room at a time.
-
- Returns:
- tuple[list, dict] (to_delete, to_insert): where to_delete are the
- type/state_keys to remove from current_state_events and `to_insert`
- are the updates to current_state_events.
- """
- existing_state = yield self.get_current_state_ids(room_id)
-
- to_delete = [key for key in existing_state if key not in current_state]
-
- to_insert = {
- key: ev_id
- for key, ev_id in iteritems(current_state)
- if ev_id != existing_state.get(key)
- }
-
- defer.returnValue((to_delete, to_insert))
+ return existing_prevs
@log_function
def _persist_events_txn(
self,
- txn,
- events_and_contexts,
- backfilled,
- delete_existing=False,
- state_delta_for_room={},
- new_forward_extremeties={},
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ backfilled: bool,
+ delete_existing: bool = False,
+ state_delta_for_room: Dict[str, DeltaState] = {},
+ new_forward_extremeties: Dict[str, List[str]] = {},
):
"""Insert some number of room events into the necessary database tables.
@@ -846,21 +365,16 @@ class EventsStore(
whether the event was rejected.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- events_and_contexts (list[(EventBase, EventContext)]):
- events to persist
- backfilled (bool): True if the events were backfilled
- delete_existing (bool): True to purge existing table rows for the
- events from the database. This is useful when retrying due to
+ txn
+ events_and_contexts: events to persist
+ backfilled: True if the events were backfilled
+            delete_existing: True to purge existing table rows for the events
+ from the database. This is useful when retrying due to
IntegrityError.
- state_delta_for_room (dict[str, (list, dict)]):
- The current-state delta for each room. For each room, a tuple
- (to_delete, to_insert), being a list of type/state keys to be
- removed from the current state, and a state set to be added to
- the current state.
- new_forward_extremeties (dict[str, list[str]]):
- The new forward extremities for each room. For each room, a
- list of the event ids which are the forward extremities.
+ state_delta_for_room: The current-state delta for each room.
+            new_forward_extremeties: The new forward extremities for each room.
+ For each room, a list of the event ids which are the forward
+ extremities.
"""
all_events_and_contexts = events_and_contexts
@@ -868,8 +382,6 @@ class EventsStore(
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
-
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremeties,
@@ -912,7 +424,7 @@ class EventsStore(
# event's auth chain, but its easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_auth",
values=[
@@ -943,88 +455,135 @@ class EventsStore(
backfilled=backfilled,
)
- def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
- for room_id, current_state_tuple in iteritems(state_delta_by_room):
- to_delete, to_insert = current_state_tuple
-
- # First we add entries to the current_state_delta_stream. We
- # do this before updating the current_state_events table so
- # that we can use it to calculate the `prev_event_id`. (This
- # allows us to not have to pull out the existing state
- # unnecessarily).
- #
- # The stream_id for the update is chosen to be the minimum of the stream_ids
- # for the batch of the events that we are persisting; that means we do not
- # end up in a situation where workers see events before the
- # current_state_delta updates.
- #
- sql = """
- INSERT INTO current_state_delta_stream
- (stream_id, room_id, type, state_key, event_id, prev_event_id)
- SELECT ?, ?, ?, ?, ?, (
- SELECT event_id FROM current_state_events
- WHERE room_id = ? AND type = ? AND state_key = ?
+ # We call this last as it assumes we've inserted the events into
+ # room_memberships, where applicable.
+ self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
+ def _update_current_state_txn(
+ self,
+ txn: LoggingTransaction,
+ state_delta_by_room: Dict[str, DeltaState],
+ stream_id: int,
+ ):
+ for room_id, delta_state in iteritems(state_delta_by_room):
+ to_delete = delta_state.to_delete
+ to_insert = delta_state.to_insert
+
+ if delta_state.no_longer_in_room:
+ # Server is no longer in the room so we delete the room from
+ # current_state_events, being careful we've already updated the
+ # rooms.room_version column (which gets populated in a
+ # background task).
+ self._upsert_room_version_txn(txn, room_id)
+
+ # Before deleting we populate the current_state_delta_stream
+ # so that async background tasks get told what happened.
+ sql = """
+ INSERT INTO current_state_delta_stream
+ (stream_id, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, room_id, type, state_key, null, event_id
+ FROM current_state_events
+ WHERE room_id = ?
+ """
+ txn.execute(sql, (stream_id, room_id))
+
+ self.db.simple_delete_txn(
+ txn, table="current_state_events", keyvalues={"room_id": room_id},
)
- """
- txn.executemany(
- sql,
- (
- (
- stream_id,
- room_id,
- etype,
- state_key,
- None,
- room_id,
- etype,
- state_key,
+ else:
+ # We're still in the room, so we update the current state as normal.
+
+ # First we add entries to the current_state_delta_stream. We
+ # do this before updating the current_state_events table so
+ # that we can use it to calculate the `prev_event_id`. (This
+ # allows us to not have to pull out the existing state
+ # unnecessarily).
+ #
+ # The stream_id for the update is chosen to be the minimum of the stream_ids
+ # for the batch of the events that we are persisting; that means we do not
+ # end up in a situation where workers see events before the
+ # current_state_delta updates.
+ #
+ sql = """
+ INSERT INTO current_state_delta_stream
+ (stream_id, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, ?, ?, ?, (
+ SELECT event_id FROM current_state_events
+ WHERE room_id = ? AND type = ? AND state_key = ?
)
- for etype, state_key in to_delete
- # We sanity check that we're deleting rather than updating
- if (etype, state_key) not in to_insert
- ),
- )
- txn.executemany(
- sql,
- (
+ """
+ txn.executemany(
+ sql,
(
- stream_id,
- room_id,
- etype,
- state_key,
- ev_id,
- room_id,
- etype,
- state_key,
- )
- for (etype, state_key), ev_id in iteritems(to_insert)
- ),
- )
+ (
+ stream_id,
+ room_id,
+ etype,
+ state_key,
+ to_insert.get((etype, state_key)),
+ room_id,
+ etype,
+ state_key,
+ )
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
+ )
+ # Now we actually update the current_state_events table
- # Now we actually update the current_state_events table
+ txn.executemany(
+ "DELETE FROM current_state_events"
+ " WHERE room_id = ? AND type = ? AND state_key = ?",
+ (
+ (room_id, etype, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
+ )
- txn.executemany(
- "DELETE FROM current_state_events"
- " WHERE room_id = ? AND type = ? AND state_key = ?",
- (
- (room_id, etype, state_key)
- for etype, state_key in itertools.chain(to_delete, to_insert)
- ),
- )
+ # We include the membership in the current state table, hence we do
+ # a lookup when we insert. This assumes that all events have already
+ # been inserted into room_memberships.
+ txn.executemany(
+ """INSERT INTO current_state_events
+ (room_id, type, state_key, event_id, membership)
+ VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+ """,
+ [
+ (room_id, key[0], key[1], ev_id, ev_id)
+ for key, ev_id in iteritems(to_insert)
+ ],
+ )
- self._simple_insert_many_txn(
- txn,
- table="current_state_events",
- values=[
- {
- "event_id": ev_id,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- }
- for key, ev_id in iteritems(to_insert)
- ],
- )
+ # We now update `local_current_membership`. We do this regardless
+ # of whether we're still in the room or not to handle the case where
+ # e.g. we just got banned (where we need to record that fact here).
+
+ # Note: Do we really want to delete rows here (that we do not
+ # subsequently reinsert below)? While technically correct it means
+ # we have no record of the fact the user *was* a member of the
+ # room but got, say, state reset out of it.
+ if to_delete or to_insert:
+ txn.executemany(
+ "DELETE FROM local_current_membership"
+ " WHERE room_id = ? AND user_id = ?",
+ (
+ (room_id, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ if etype == EventTypes.Member and self.is_mine_id(state_key)
+ ),
+ )
+
+ if to_insert:
+ txn.executemany(
+ """INSERT INTO local_current_membership
+ (room_id, user_id, event_id, membership)
+ VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+ """,
+ [
+ (room_id, key[1], ev_id, ev_id)
+ for key, ev_id in to_insert.items()
+ if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+ ],
+ )
txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
@@ -1039,11 +598,11 @@ class EventsStore(
# We find out which membership events we may have deleted
        # and which we have added, then we invalidate the caches for all
# those users.
- members_changed = set(
+ members_changed = {
state_key
for ev_type, state_key in itertools.chain(to_delete, to_insert)
if ev_type == EventTypes.Member
- )
+ }
for member in members_changed:
txn.call_after(
@@ -1052,16 +611,45 @@ class EventsStore(
self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+ def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+ """Update the room version in the database based off current state
+ events.
+
+ This is used when we're about to delete current state and we want to
+ ensure that the `rooms.room_version` column is up to date.
+ """
+
+ sql = """
+ SELECT json FROM event_json
+ INNER JOIN current_state_events USING (room_id, event_id)
+ WHERE room_id = ? AND type = ? AND state_key = ?
+ """
+ txn.execute(sql, (room_id, EventTypes.Create, ""))
+ row = txn.fetchone()
+ if row:
+ event_json = json.loads(row[0])
+ content = event_json.get("content", {})
+ creator = content.get("creator")
+ room_version_id = content.get("room_version", RoomVersions.V1.identifier)
+
+ self.db.simple_upsert_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ values={"room_version": room_version_id},
+ insertion_values={"is_public": False, "creator": creator},
+ )
+
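To make the `simple_upsert_txn` call above concrete: roughly, `keyvalues` names the conflict key, `values` are written on both insert and update, and `insertion_values` are only set when the row is first created. A standalone sqlite3 sketch of an equivalent native upsert (column values are invented; needs SQLite >= 3.24 for ON CONFLICT ... DO UPDATE):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE rooms (room_id TEXT PRIMARY KEY, room_version TEXT,"
        " is_public BOOLEAN, creator TEXT)"
    )
    upsert = """
        INSERT INTO rooms (room_id, room_version, is_public, creator)
        VALUES (?, ?, ?, ?)
        ON CONFLICT (room_id) DO UPDATE SET room_version = excluded.room_version
    """
    conn.execute(upsert, ("!r:example.org", "1", 0, "@creator:example.org"))
    conn.execute(upsert, ("!r:example.org", "5", 1, "@someone_else:example.org"))
    print(conn.execute("SELECT room_version, is_public, creator FROM rooms").fetchone())
    # ('5', 0, '@creator:example.org'): only room_version changed on the second call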
def _update_forward_extremities_txn(
self, txn, new_forward_extremities, max_stream_order
):
for room_id, new_extrem in iteritems(new_forward_extremities):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
@@ -1074,7 +662,7 @@ class EventsStore(
# new stream_ordering to new forward extremities in the room.
# This allows us to later efficiently look up the forward extremities
# for a room before a given stream_ordering
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
@@ -1191,16 +779,14 @@ class EventsStore(
metadata_json = encode_json(event.internal_metadata.get_dict())
- sql = (
- "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?"
- )
+ sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="ex_outlier_stream",
values={
@@ -1210,7 +796,7 @@ class EventsStore(
},
)
- sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?"
+ sql = "UPDATE events SET outlier = ? WHERE event_id = ?"
txn.execute(sql, (False, event.event_id))
# Update the event_backward_extremities table now that this
@@ -1236,15 +822,11 @@ class EventsStore(
"event_reference_hashes",
"event_search",
"event_to_state_groups",
- "guest_access",
- "history_visibility",
"local_invites",
- "room_names",
"state_events",
"rejections",
"redactions",
"room_memberships",
- "topics",
):
txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,),
@@ -1276,7 +858,7 @@ class EventsStore(
d.pop("redacted_because", None)
return d
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_json",
values=[
@@ -1293,7 +875,7 @@ class EventsStore(
],
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="events",
values=[
@@ -1318,6 +900,18 @@ class EventsStore(
],
)
+ for event, _ in events_and_contexts:
+ if not event.internal_metadata.is_redacted():
+ # If we're persisting an unredacted event we go and ensure
+ # that we mark any redactions that reference this event as
+ # requiring censoring.
+ self.db.simple_update_txn(
+ txn,
+ table="redactions",
+ keyvalues={"redacts": event.event_id},
+ updatevalues={"have_censored": False},
+ )
+
def _store_rejected_events_txn(self, txn, events_and_contexts):
"""Add rows to the 'rejections' table for received events which were
rejected
@@ -1388,29 +982,36 @@ class EventsStore(
for event, _ in events_and_contexts:
if event.type == EventTypes.Name:
- # Insert into the room_names and event_search tables.
+ # Insert into the event_search table.
self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic:
- # Insert into the topics table and event_search table.
+ # Insert into the event_search table.
self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
# Insert into the event_search table.
self._store_room_message_txn(txn, event)
- elif event.type == EventTypes.Redaction:
+ elif event.type == EventTypes.Redaction and event.redacts is not None:
# Insert into the redactions table.
self._store_redaction(txn, event)
- elif event.type == EventTypes.RoomHistoryVisibility:
- # Insert into the event_search table.
- self._store_history_visibility_txn(txn, event)
- elif event.type == EventTypes.GuestAccess:
- # Insert into the event_search table.
- self._store_guest_access_txn(txn, event)
elif event.type == EventTypes.Retention:
# Update the room_retention table.
self._store_retention_policy_for_room_txn(txn, event)
self._handle_event_relations(txn, event)
+ # Store the labels for this event.
+ labels = event.content.get(EventContentFields.LABELS)
+ if labels:
+ self.insert_labels_for_event_txn(
+ txn, event.event_id, labels, event.room_id, event.depth
+ )
+
+ if self._ephemeral_messages_enabled:
+ # If there's an expiry timestamp on the event, store it.
+ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+ if isinstance(expiry_ts, int) and not event.is_state():
+ self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
+
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@@ -1446,7 +1047,7 @@ class EventsStore(
state_values.append(vals)
- self._simple_insert_many_txn(txn, table="state_events", values=state_values)
+ self.db.simple_insert_many_txn(txn, table="state_events", values=state_values)
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
@@ -1469,11 +1070,15 @@ class EventsStore(
" FROM events as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
- " WHERE e.event_id IN (%s)"
- ) % (",".join(["?"] * len(ev_map)),)
+ " WHERE "
+ )
- txn.execute(sql, list(ev_map))
- rows = self.cursor_to_dict(txn)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "e.event_id", list(ev_map)
+ )
+
+ txn.execute(sql + clause, args)
+ rows = self.db.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
@@ -1490,9 +1095,118 @@ class EventsStore(
def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event
txn.call_after(self._invalidate_get_event_cache, event.redacts)
- txn.execute(
- "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
- (event.event_id, event.redacts),
+
+ self.db.simple_insert_txn(
+ txn,
+ table="redactions",
+ values={
+ "event_id": event.event_id,
+ "redacts": event.redacts,
+ "received_ts": self._clock.time_msec(),
+ },
+ )
+
+ async def _censor_redactions(self):
+ """Censors all redactions older than the configured period that haven't
+ been censored yet.
+
+ By censor we mean update the event_json table with the redacted event.
+ """
+
+ if self.hs.config.redaction_retention_period is None:
+ return
+
+ if not (
+ await self.db.updates.has_completed_background_update(
+ "redactions_have_censored_ts_idx"
+ )
+ ):
+ # We don't want to run this until the appropriate index has been
+ # created.
+ return
+
+ before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+
+ # We fetch all redactions that:
+ # 1. point to an event we have,
+ # 2. have a received_ts from before the cutoff, and
+ # 3. we haven't yet censored.
+ #
+ # This is limited to 100 events to ensure that we don't try and do too
+ # much at once. We'll get called again so this should eventually catch
+ # up.
+ sql = """
+ SELECT redactions.event_id, redacts FROM redactions
+ LEFT JOIN events AS original_event ON (
+ redacts = original_event.event_id
+ )
+ WHERE NOT have_censored
+ AND redactions.received_ts <= ?
+ ORDER BY redactions.received_ts ASC
+ LIMIT ?
+ """
+
+ rows = await self.db.execute(
+ "_censor_redactions_fetch", None, sql, before_ts, 100
+ )
+
+ updates = []
+
+ for redaction_id, event_id in rows:
+ redaction_event = await self.get_event(redaction_id, allow_none=True)
+ original_event = await self.get_event(
+ event_id, allow_rejected=True, allow_none=True
+ )
+
+ # The SQL above ensures that we have both the redaction and
+ # original event, so if the `get_event` calls return None it
+ # means that the redaction wasn't allowed. Either way we know that
+ # the result won't change so we mark the fact that we've checked.
+ if (
+ redaction_event
+ and original_event
+ and original_event.internal_metadata.is_redacted()
+ ):
+ # Redaction was allowed
+ pruned_json = encode_json(
+ prune_event_dict(
+ original_event.room_version, original_event.get_dict()
+ )
+ )
+ else:
+ # Redaction wasn't allowed
+ pruned_json = None
+
+ updates.append((redaction_id, event_id, pruned_json))
+
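+ # Apply the updates in a single transaction: where the redaction was
+ # valid we overwrite the original event's JSON with the pruned copy,
+ # and in every case we mark the redaction as censored so we don't
+ # check it again.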
+ def _update_censor_txn(txn):
+ for redaction_id, event_id, pruned_json in updates:
+ if pruned_json:
+ self._censor_event_txn(txn, event_id, pruned_json)
+
+ self.db.simple_update_one_txn(
+ txn,
+ table="redactions",
+ keyvalues={"event_id": redaction_id},
+ updatevalues={"have_censored": True},
+ )
+
+ await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+
+ def _censor_event_txn(self, txn, event_id, pruned_json):
+ """Censor an event by replacing its JSON in the event_json table with the
+ provided pruned JSON.
+
+ Args:
+ txn (LoggingTransaction): The database transaction.
+ event_id (str): The ID of the event to censor.
+ pruned_json (str): The pruned JSON
+ """
+ self.db.simple_update_one_txn(
+ txn,
+ table="event_json",
+ keyvalues={"event_id": event_id},
+ updatevalues={"json": pruned_json},
)
@defer.inlineCallbacks
@@ -1511,11 +1225,11 @@ class EventsStore(
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_messages", _count_messages)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("count_messages", _count_messages)
+ return ret
@defer.inlineCallbacks
def count_daily_sent_messages(self):
@@ -1532,11 +1246,11 @@ class EventsStore(
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
+ return ret
@defer.inlineCallbacks
def count_daily_active_rooms(self):
@@ -1547,11 +1261,11 @@ class EventsStore(
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_daily_active_rooms", _count)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
+ return ret
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
@@ -1602,7 +1316,7 @@ class EventsStore(
return new_event_updates
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
@@ -1647,7 +1361,7 @@ class EventsStore(
return new_event_updates
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
@@ -1740,7 +1454,7 @@ class EventsStore(
backward_ex_outliers,
)
- return self.runInteraction("get_all_new_events", get_all_new_events_txn)
+ return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
def purge_history(self, room_id, token, delete_local_events):
"""Deletes room history before a certain point
@@ -1754,9 +1468,13 @@ class EventsStore(
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
+
+ Returns:
+ Deferred[set[int]]: The set of state groups that are referenced by
+ deleted events.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
@@ -1854,7 +1572,7 @@ class EventsStore(
# We do joins against events_to_purge for e.g. calculating state
# groups to purge, etc., so lets make an index.
- txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)")
+ txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
txn.execute("SELECT event_id, should_delete FROM events_to_purge")
event_rows = txn.fetchall()
@@ -1890,11 +1608,10 @@ class EventsStore(
[(room_id, event_id) for event_id, in new_backwards_extrems],
)
- logger.info("[purge] finding redundant state groups")
+ logger.info("[purge] finding state groups referenced by deleted events")
# Get all state groups that are referenced by events that are to be
- # deleted. We then go and check if they are referenced by other events
- # or state groups, and if not we delete them.
+ # deleted.
txn.execute(
"""
SELECT DISTINCT state_group FROM events_to_purge
@@ -1902,65 +1619,11 @@ class EventsStore(
"""
)
- referenced_state_groups = set(sg for sg, in txn)
+ referenced_state_groups = {sg for sg, in txn}
logger.info(
"[purge] found %i referenced state groups", len(referenced_state_groups)
)
- logger.info("[purge] finding state groups that can be deleted")
-
- _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups)
- state_groups_to_delete, remaining_state_groups = _
-
- logger.info(
- "[purge] found %i state groups to delete", len(state_groups_to_delete)
- )
-
- logger.info(
- "[purge] de-delta-ing %i remaining state groups",
- len(remaining_state_groups),
- )
-
- # Now we turn the state groups that reference to-be-deleted state
- # groups to non delta versions.
- for sg in remaining_state_groups:
- logger.info("[purge] de-delta-ing remaining state group %s", sg)
- curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
- curr_state = curr_state[sg]
-
- self._simple_delete_txn(
- txn, table="state_groups_state", keyvalues={"state_group": sg}
- )
-
- self._simple_delete_txn(
- txn, table="state_group_edges", keyvalues={"state_group": sg}
- )
-
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": sg,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(curr_state)
- ],
- )
-
- logger.info("[purge] removing redundant state groups")
- txn.executemany(
- "DELETE FROM state_groups_state WHERE state_group = ?",
- ((sg,) for sg in state_groups_to_delete),
- )
- txn.executemany(
- "DELETE FROM state_groups WHERE id = ?",
- ((sg,) for sg in state_groups_to_delete),
- )
-
logger.info("[purge] removing events from event_to_state_groups")
txn.execute(
"DELETE FROM event_to_state_groups "
@@ -2032,7 +1695,7 @@ class EventsStore(
""",
(room_id,),
)
- min_depth, = txn.fetchone()
+ (min_depth,) = txn.fetchone()
logger.info("[purge] updating room_depth to %d", min_depth)
@@ -2047,98 +1710,135 @@ class EventsStore(
logger.info("[purge] done")
- def _find_unreferenced_groups_during_purge(self, txn, state_groups):
- """Used when purging history to figure out which state groups can be
- deleted and which need to be de-delta'ed (due to one of its prev groups
- being scheduled for deletion).
+ return referenced_state_groups
+
+ def purge_room(self, room_id):
+ """Deletes all record of a room
Args:
- txn
- state_groups (set[int]): Set of state groups referenced by events
- that are going to be deleted.
+ room_id (str)
Returns:
- tuple[set[int], set[int]]: The set of state groups that can be
- deleted and the set of state groups that need to be de-delta'ed
+ Deferred[List[int]]: The list of state groups to delete.
"""
- # Graph of state group -> previous group
- graph = {}
-
- # Set of events that we have found to be referenced by events
- referenced_groups = set()
-
- # Set of state groups we've already seen
- state_groups_seen = set(state_groups)
-
- # Set of state groups to handle next.
- next_to_search = set(state_groups)
- while next_to_search:
- # We bound size of groups we're looking up at once, to stop the
- # SQL query getting too big
- if len(next_to_search) < 100:
- current_search = next_to_search
- next_to_search = set()
- else:
- current_search = set(itertools.islice(next_to_search, 100))
- next_to_search -= current_search
- # Check if state groups are referenced
- sql = """
- SELECT DISTINCT state_group FROM event_to_state_groups
- LEFT JOIN events_to_purge AS ep USING (event_id)
- WHERE state_group IN (%s) AND ep.event_id IS NULL
- """ % (
- ",".join("?" for _ in current_search),
- )
- txn.execute(sql, list(current_search))
+ return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
- referenced = set(sg for sg, in txn)
- referenced_groups |= referenced
+ def _purge_room_txn(self, txn, room_id):
+ # First we fetch all the state groups that should be deleted, before
+ # we delete that information.
+ txn.execute(
+ """
+ SELECT DISTINCT state_group FROM events
+ INNER JOIN event_to_state_groups USING(event_id)
+ WHERE events.room_id = ?
+ """,
+ (room_id,),
+ )
- # We don't continue iterating up the state group graphs for state
- # groups that are referenced.
- current_search -= referenced
+ state_groups = [row[0] for row in txn]
- rows = self._simple_select_many_txn(
- txn,
- table="state_group_edges",
- column="prev_state_group",
- iterable=current_search,
- keyvalues={},
- retcols=("prev_state_group", "state_group"),
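+ # The deletion happens in two passes: event-indexed tables first (via a
+ # subquery on events.room_id), then the room_id-indexed tables, which
+ # means the events table is still present to drive that subquery.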
+ # Now we delete tables which lack an index on room_id but have one on event_id
+ for table in (
+ "event_auth",
+ "event_edges",
+ "event_push_actions_staging",
+ "event_reference_hashes",
+ "event_relations",
+ "event_to_state_groups",
+ "redactions",
+ "rejections",
+ "state_events",
+ ):
+ logger.info("[purge] removing %s from %s", room_id, table)
+
+ txn.execute(
+ """
+ DELETE FROM %s WHERE event_id IN (
+ SELECT event_id FROM events WHERE room_id=?
+ )
+ """
+ % (table,),
+ (room_id,),
)
- prevs = set(row["state_group"] for row in rows)
- # We don't bother re-handling groups we've already seen
- prevs -= state_groups_seen
- next_to_search |= prevs
- state_groups_seen |= prevs
+ # and finally, the tables with an index on room_id (or no useful index)
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ # no useful index, but let's clear them anyway
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ "local_current_membership",
+ ):
+ logger.info("[purge] removing %s from %s", room_id, table)
+ txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
- for row in rows:
- # Note: Each state group can have at most one prev group
- graph[row["state_group"]] = row["prev_state_group"]
+ # Other tables we do NOT need to clear out:
+ #
+ # - blocked_rooms
+ # This is important, to make sure that we don't accidentally rejoin a blocked
+ # room after it was purged
+ #
+ # - user_directory
+ # This has a room_id column, but it is unused
+ #
+
+ # Other tables that we might want to consider clearing out include:
+ #
+ # - event_reports
+ # Given that these are intended for abuse management, my initial
+ # inclination is to leave them in place.
+ #
+ # - current_state_delta_stream
+ # - ex_outlier_stream
+ # - room_tags_revisions
+ # The problem with these is that they are largeish and there is no room_id
+ # index on them. In any case we should be clearing out 'stream' tables
+ # periodically anyway (#5888)
- to_delete = state_groups_seen - referenced_groups
+ # TODO: we could probably usefully do a bunch of cache invalidation here
- to_dedelta = set()
- for sg in referenced_groups:
- prev_sg = graph.get(sg)
- if prev_sg and prev_sg in to_delete:
- to_dedelta.add(sg)
+ logger.info("[purge] done")
- return to_delete, to_dedelta
+ return state_groups
- @defer.inlineCallbacks
- def is_event_after(self, event_id1, event_id2):
+ async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
- to_1, so_1 = yield self._get_event_ordering(event_id1)
- to_2, so_2 = yield self._get_event_ordering(event_id2)
- defer.returnValue((to_1, so_1) > (to_2, so_2))
+ to_1, so_1 = await self._get_event_ordering(event_id1)
+ to_2, so_2 = await self._get_event_ordering(event_id2)
+ return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id):
- res = yield self._simple_select_one(
+ res = yield self.db.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
@@ -2148,9 +1848,7 @@ class EventsStore(
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
- defer.returnValue(
- (int(res["topological_ordering"]), int(res["stream_ordering"]))
- )
+ return (int(res["topological_ordering"]), int(res["stream_ordering"]))
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
@@ -2163,11 +1861,135 @@ class EventsStore(
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
+ def insert_labels_for_event_txn(
+ self, txn, event_id, labels, room_id, topological_ordering
+ ):
+ """Store the mapping between an event's ID and its labels, with one row per
+ (event_id, label) tuple.
+
+ Args:
+ txn (LoggingTransaction): The transaction to execute.
+ event_id (str): The event's ID.
+ labels (list[str]): A list of text labels.
+ room_id (str): The ID of the room the event was sent to.
+ topological_ordering (int): The position of the event in the room's topology.
+ """
+ return self.db.simple_insert_many_txn(
+ txn=txn,
+ table="event_labels",
+ values=[
+ {
+ "event_id": event_id,
+ "label": label,
+ "room_id": room_id,
+ "topological_ordering": topological_ordering,
+ }
+ for label in labels
+ ],
+ )
+
+ def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+ """Save the expiry timestamp associated with a given event ID.
+
+ Args:
+ txn (LoggingTransaction): The database transaction to use.
+ event_id (str): The event ID the expiry timestamp is associated with.
+ expiry_ts (int): The timestamp at which to expire (delete) the event.
+ """
+ return self.db.simple_insert_txn(
+ txn=txn,
+ table="event_expiry",
+ values={"event_id": event_id, "expiry_ts": expiry_ts},
+ )
+
+ @defer.inlineCallbacks
+ def expire_event(self, event_id):
+ """Retrieve and expire an event that has expired, and delete its associated
+ expiry timestamp. If the event can't be retrieved, delete its associated
+ timestamp so we don't try to expire it again in the future.
+
+ Args:
+ event_id (str): The ID of the event to delete.
+ """
+ # Try to retrieve the event's content from the database or the event cache.
+ event = yield self.get_event(event_id)
+
+ def delete_expired_event_txn(txn):
+ # Delete the expiry timestamp associated with this event from the database.
+ self._delete_event_expiry_txn(txn, event_id)
+
+ if not event:
+ # If we can't find the event, log a warning and delete the expiry date
+ # from the database so that we don't try to expire it again in the
+ # future.
+ logger.warning(
+ "Can't expire event %s because we don't have it.", event_id
+ )
+ return
+
+ # Prune the event's dict then convert it to JSON.
+ pruned_json = encode_json(
+ prune_event_dict(event.room_version, event.get_dict())
+ )
+
+ # Update the event_json table to replace the event's JSON with the pruned
+ # JSON.
+ self._censor_event_txn(txn, event.event_id, pruned_json)
+
+ # We need to invalidate the event cache entry for this event because we
+ # changed its content in the database. We can't call
+ # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+ # right type.
+ txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+ # Send that invalidation to replication so that other workers also invalidate
+ # the event cache.
+ self._send_invalidation_to_replication(
+ txn, "_get_event_cache", (event.event_id,)
+ )
+
+ yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+ def _delete_event_expiry_txn(self, txn, event_id):
+ """Delete the expiry timestamp associated with an event ID without deleting the
+ actual event.
+
+ Args:
+ txn (LoggingTransaction): The transaction to use to perform the deletion.
+ event_id (str): The event ID to delete the associated expiry timestamp of.
+ """
+ return self.db.simple_delete_txn(
+ txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+ )
+
+ def get_next_event_to_expire(self):
+ """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+ table, or None if there's no more event to expire.
+
+ Returns: Deferred[Optional[Tuple[str, int]]]
+ A tuple containing the event ID as its first element and an expiry timestamp
+ as its second one, if there's at least one row in the event_expiry table.
+ None otherwise.
+ """
+
+ def get_next_event_to_expire_txn(txn):
+ txn.execute(
+ """
+ SELECT event_id, expiry_ts FROM event_expiry
+ ORDER BY expiry_ts ASC LIMIT 1
+ """
+ )
+
+ return txn.fetchone()
+
+ return self.db.runInteraction(
+ desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+ )
+
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 75c1935bf3..f54c8b1ee0 100644
--- a/synapse/storage/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -21,29 +21,31 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.api.constants import EventContentFields
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
logger = logging.getLogger(__name__)
-class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
+class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
- def __init__(self, db_conn, hs):
- super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender,
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"event_contains_url_index",
index_name="event_contains_url_index",
table="events",
@@ -54,7 +56,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
# an event_id index on event_search is useful for the purge_history
# api. Plus it means we get to enforce some integrity with a UNIQUE
# clause
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"event_search_event_id_idx",
index_name="event_search_event_id_idx",
table="event_search",
@@ -63,9 +65,37 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
psql_only=True,
)
- self.register_background_update_handler(
- self.DELETE_SOFT_FAILED_EXTREMITIES,
- self._cleanup_extremities_bg_update,
+ self.db.updates.register_background_update_handler(
+ self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
+ )
+
+ self.db.updates.register_background_update_handler(
+ "redactions_received_ts", self._redactions_received_ts
+ )
+
+ # This index gets deleted in the `event_fix_redactions_bytes` update
+ self.db.updates.register_background_index_update(
+ "event_fix_redactions_bytes_create_index",
+ index_name="redactions_censored_redacts",
+ table="redactions",
+ columns=["redacts"],
+ where_clause="have_censored",
+ )
+
+ self.db.updates.register_background_update_handler(
+ "event_fix_redactions_bytes", self._event_fix_redactions_bytes
+ )
+
+ self.db.updates.register_background_update_handler(
+ "event_store_labels", self._event_store_labels
+ )
+
+ self.db.updates.register_background_index_update(
+ "redactions_have_censored_ts_idx",
+ index_name="redactions_have_censored_ts",
+ table="redactions",
+ columns=["received_ts"],
+ where_clause="NOT have_censored",
)
@defer.inlineCallbacks
@@ -123,20 +153,22 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"rows_inserted": rows_inserted + len(rows),
}
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
)
return len(rows)
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
- yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
+ yield self.db.updates._end_background_update(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
+ )
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
@@ -167,7 +199,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
- ev_rows = self._simple_select_many_txn(
+ ev_rows = self.db.simple_select_many_txn(
txn,
table="event_json",
column="event_id",
@@ -200,20 +232,22 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"rows_inserted": rows_inserted + len(rows_to_update),
}
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
return len(rows_to_update)
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
- yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
+ yield self.db.updates._end_background_update(
+ self.EVENT_ORIGIN_SERVER_TS_NAME
+ )
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _cleanup_extremities_bg_update(self, progress, batch_size):
@@ -269,7 +303,8 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
LEFT JOIN events USING (event_id)
LEFT JOIN event_json USING (event_id)
LEFT JOIN rejections USING (event_id)
- """, (batch_size,)
+ """,
+ (batch_size,),
)
for prev_event_id, event_id, metadata, rejected, outlier in txn:
@@ -308,12 +343,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
INNER JOIN event_json USING (event_id)
LEFT JOIN rejections USING (event_id)
WHERE
- prev_event_id IN (%s)
- AND NOT events.outlier
- """ % (
- ",".join("?" for _ in to_check),
+ NOT events.outlier
+ AND
+ """
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "prev_event_id", to_check
)
- txn.execute(sql, to_check)
+ txn.execute(sql + clause, list(args))
for prev_event_id, event_id, metadata, rejected in txn:
if event_id in graph:
@@ -342,7 +378,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
to_delete.intersection_update(original_set)
- deleted = self._simple_delete_many_txn(
+ deleted = self.db.simple_delete_many_txn(
txn=txn,
table="event_forward_extremities",
column="event_id",
@@ -358,22 +394,21 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if deleted:
# We now need to invalidate the caches of these rooms
- rows = self._simple_select_many_txn(
+ rows = self.db.simple_select_many_txn(
txn,
table="events",
column="event_id",
iterable=to_delete,
keyvalues={},
- retcols=("room_id",)
+ retcols=("room_id",),
)
- room_ids = set(row["room_id"] for row in rows)
+ room_ids = {row["room_id"] for row in rows}
for room_id in room_ids:
txn.call_after(
- self.get_latest_event_ids_in_room.invalidate,
- (room_id,)
+ self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
- self._simple_delete_many_txn(
+ self.db.simple_delete_many_txn(
txn=txn,
table="_extremities_to_check",
column="event_id",
@@ -383,19 +418,172 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
return len(original_set)
- num_handled = yield self.runInteraction(
- "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn,
+ num_handled = yield self.db.runInteraction(
+ "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
- yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES)
+ yield self.db.updates._end_background_update(
+ self.DELETE_SOFT_FAILED_EXTREMITIES
+ )
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
- yield self.runInteraction(
- "_cleanup_extremities_bg_update_drop_table",
- _drop_table_txn,
+ yield self.db.runInteraction(
+ "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
+ )
+
+ return num_handled
+
+ @defer.inlineCallbacks
+ def _redactions_received_ts(self, progress, batch_size):
+ """Handles filling out the `received_ts` column in redactions.
+ """
+ last_event_id = progress.get("last_event_id", "")
+
+ def _redactions_received_ts_txn(txn):
+ # Fetch the set of event IDs that we want to update
+ sql = """
+ SELECT event_id FROM redactions
+ WHERE event_id > ?
+ ORDER BY event_id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_event_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ (upper_event_id,) = rows[-1]
+
+ # Update the redactions with the received_ts.
+ #
+ # Note: Not all events have an associated received_ts, so we
+ # fall back to using origin_server_ts. If we for some reason don't
+ # have an origin_server_ts, let's just use the current timestamp.
+ #
+ # We don't want to leave it null, as then we'll never try to
+ # censor those redactions.
+ sql = """
+ UPDATE redactions
+ SET received_ts = (
+ SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events
+ WHERE events.event_id = redactions.event_id
+ )
+ WHERE ? <= event_id AND event_id <= ?
+ """
+
+ txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
+
+ self.db.updates._background_update_progress_txn(
+ txn, "redactions_received_ts", {"last_event_id": upper_event_id}
+ )
+
+ return len(rows)
+
+ count = yield self.db.runInteraction(
+ "_redactions_received_ts", _redactions_received_ts_txn
+ )
+
+ if not count:
+ yield self.db.updates._end_background_update("redactions_received_ts")
+
+ return count
+
+ @defer.inlineCallbacks
+ def _event_fix_redactions_bytes(self, progress, batch_size):
+ """Undoes the hex encoding applied to censored (redacted) event JSON.
+ """
+
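+ # Note: convert_from(...) and the ::bytea cast below are
+ # Postgres-specific, so this update is presumably only registered on
+ # Postgres deployments.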
+ def _event_fix_redactions_bytes_txn(txn):
+ # This update is quite fast due to the new index.
+ txn.execute(
+ """
+ UPDATE event_json
+ SET
+ json = convert_from(json::bytea, 'utf8')
+ FROM redactions
+ WHERE
+ redactions.have_censored
+ AND event_json.event_id = redactions.redacts
+ AND json NOT LIKE '{%';
+ """
)
- defer.returnValue(num_handled)
+ txn.execute("DROP INDEX redactions_censored_redacts")
+
+ yield self.db.runInteraction(
+ "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
+ )
+
+ yield self.db.updates._end_background_update("event_fix_redactions_bytes")
+
+ return 1
+
+ @defer.inlineCallbacks
+ def _event_store_labels(self, progress, batch_size):
+ """Background update handler which will store labels for existing events."""
+ last_event_id = progress.get("last_event_id", "")
+
+ def _event_store_labels_txn(txn):
+ txn.execute(
+ """
+ SELECT event_id, json FROM event_json
+ LEFT JOIN event_labels USING (event_id)
+ WHERE event_id > ? AND label IS NULL
+ ORDER BY event_id LIMIT ?
+ """,
+ (last_event_id, batch_size),
+ )
+
+ results = list(txn)
+
+ nbrows = 0
+ last_row_event_id = ""
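+ # Process the batch row by row so that an event whose JSON can't be
+ # loaded only skips its own labels rather than failing the whole batch.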
+ for (event_id, event_json_raw) in results:
+ try:
+ event_json = json.loads(event_json_raw)
+
+ self.db.simple_insert_many_txn(
+ txn=txn,
+ table="event_labels",
+ values=[
+ {
+ "event_id": event_id,
+ "label": label,
+ "room_id": event_json["room_id"],
+ "topological_ordering": event_json["depth"],
+ }
+ for label in event_json["content"].get(
+ EventContentFields.LABELS, []
+ )
+ if isinstance(label, str)
+ ],
+ )
+ except Exception as e:
+ logger.warning(
+ "Unable to load event %s (no labels will be imported): %s",
+ event_id,
+ e,
+ )
+
+ nbrows += 1
+ last_row_event_id = event_id
+
+ self.db.updates._background_update_progress_txn(
+ txn, "event_store_labels", {"last_event_id": last_row_event_id}
+ )
+
+ return nbrows
+
+ num_rows = yield self.db.runInteraction(
+ desc="event_store_labels", func=_event_store_labels_txn
+ )
+
+ if not num_rows:
+ yield self.db.updates._end_background_update("event_store_labels")
+
+ return num_rows
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
new file mode 100644
index 0000000000..ca237c6f12
--- /dev/null
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -0,0 +1,965 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import itertools
+import logging
+import threading
+from collections import namedtuple
+from typing import List, Optional
+
+from canonicaljson import json
+from constantly import NamedConstant, Names
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import NotFoundError
+from synapse.api.room_versions import (
+ KNOWN_ROOM_VERSIONS,
+ EventFormatVersions,
+ RoomVersions,
+)
+from synapse.events import make_event_from_dict
+from synapse.events.utils import prune_event
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
+from synapse.types import get_domain_from_id
+from synapse.util.caches.descriptors import Cache
+from synapse.util.iterutils import batch_iter
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+
+# These values are used in the `_enqueue_events` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thin air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
+
+
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+class EventRedactBehaviour(Names):
+ """
+ What to do when retrieving a redacted event from the database.
+ """
+
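+ # AS_IS: return the full event body even if it has been redacted.
+ # REDACT: return redacted events with their content pruned.
+ # BLOCK: do not return redacted events at all.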
+ AS_IS = NamedConstant()
+ REDACT = NamedConstant()
+ BLOCK = NamedConstant()
+
+
+class EventsWorkerStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventsWorkerStore, self).__init__(database, db_conn, hs)
+
+ self._get_event_cache = Cache(
+ "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
+ )
+
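+ # State for the event-fetching worker threads: a shared list of
+ # pending (event_ids, deferred) requests, a condition variable used to
+ # wake the fetchers, and a count of fetchers currently running
+ # (capped at EVENT_QUEUE_THREADS).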
+ self._event_fetch_lock = threading.Condition()
+ self._event_fetch_list = []
+ self._event_fetch_ongoing = 0
+
+ def get_received_ts(self, event_id):
+ """Get received_ts (when it was persisted) for the event.
+
+ Raises an exception for unknown events.
+
+ Args:
+ event_id (str)
+
+ Returns:
+ Deferred[int|None]: Timestamp in milliseconds, or None for events
+ that were persisted before received_ts was implemented.
+ """
+ return self.db.simple_select_one_onecol(
+ table="events",
+ keyvalues={"event_id": event_id},
+ retcol="received_ts",
+ desc="get_received_ts",
+ )
+
+ def get_received_ts_by_stream_pos(self, stream_ordering):
+ """Given a stream ordering get an approximate timestamp of when it
+ happened.
+
+ This is done by simply taking the received ts of the first event that
+ has a stream ordering greater than or equal to the given stream pos.
+ If none exists, returns the current time, on the assumption that it must
+ have happened recently.
+
+ Args:
+ stream_ordering (int)
+
+ Returns:
+ Deferred[int]
+ """
+
+ def _get_approximate_received_ts_txn(txn):
+ sql = """
+ SELECT received_ts FROM events
+ WHERE stream_ordering >= ?
+ LIMIT 1
+ """
+
+ txn.execute(sql, (stream_ordering,))
+ row = txn.fetchone()
+ if row and row[0]:
+ ts = row[0]
+ else:
+ ts = self.clock.time_msec()
+
+ return ts
+
+ return self.db.runInteraction(
+ "get_approximate_received_ts", _get_approximate_received_ts_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: bool = False,
+ check_room_id: Optional[str] = None,
+ ):
+ """Get an event from the database by event_id.
+
+ Args:
+ event_id: The event_id of the event to fetch
+
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * BLOCK - Do not return redacted events (behave as per allow_none
+ if the event is redacted)
+
+ get_prev_content: If True and event is a state event,
+ include the previous states content in the unsigned field.
+
+ allow_rejected: If True, return rejected events. Otherwise,
+ behave as per allow_none.
+
+ allow_none: If True, return None if no event found, if
+ False throw a NotFoundError
+
+ check_room_id: if not None, check the room of the found event.
+ If there is a mismatch, behave as per allow_none.
+
+ Returns:
+ Deferred[EventBase|None]
+ """
+ if not isinstance(event_id, str):
+ raise TypeError("Invalid event_id %r" % (event_id,))
+
+ events = yield self.get_events_as_list(
+ [event_id],
+ redact_behaviour=redact_behaviour,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ event = events[0] if events else None
+
+ if event is not None and check_room_id is not None:
+ if event.room_id != check_room_id:
+ event = None
+
+ if event is None and not allow_none:
+ raise NotFoundError("Could not find event %s" % (event_id,))
+
+ return event
+
+ @defer.inlineCallbacks
+ def get_events(
+ self,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ ):
+ """Get events from the database
+
+ Args:
+ event_ids: The event_ids of the events to fetch
+
+ redact_behaviour: Determine what to do with a redacted event. Possible
+ values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * BLOCK - Do not return redacted events (omit them from the response)
+
+ get_prev_content: If True and event is a state event,
+ include the previous states content in the unsigned field.
+
+ allow_rejected: If True, return rejected events. Otherwise,
+ omits rejected events from the response.
+
+ Returns:
+ Deferred : Dict from event_id to event.
+ """
+ events = yield self.get_events_as_list(
+ event_ids,
+ redact_behaviour=redact_behaviour,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ return {e.event_id: e for e in events}
+
+ @defer.inlineCallbacks
+ def get_events_as_list(
+ self,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ ):
+ """Get events from the database and return in a list in the same order
+ as given by `event_ids` arg.
+
+ Unknown events will be omitted from the response.
+
+ Args:
+ event_ids: The event_ids of the events to fetch
+
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * BLOCK - Do not return redacted events (omit them from the response)
+
+ get_prev_content: If True and event is a state event,
+ include the previous states content in the unsigned field.
+
+ allow_rejected: If True, return rejected events. Otherwise,
+ omits rejected events from the response.
+
+ Returns:
+ Deferred[list[EventBase]]: List of events fetched from the database. The
+ events are in the same order as `event_ids` arg.
+
+ Note that the returned list may be smaller than the list of event
+ IDs if not all events could be fetched.
+ """
+
+ if not event_ids:
+ return []
+
+ # there may be duplicates so we cast the list to a set
+ event_entry_map = yield self._get_events_from_cache_or_db(
+ set(event_ids), allow_rejected=allow_rejected
+ )
+
+ events = []
+ for event_id in event_ids:
+ entry = event_entry_map.get(event_id, None)
+ if not entry:
+ continue
+
+ if not allow_rejected:
+ assert not entry.event.rejected_reason, (
+ "rejected event returned from _get_events_from_cache_or_db despite "
+ "allow_rejected=False"
+ )
+
+ # We may not have had the original event when we received a redaction, so
+ # we have to recheck auth now.
+
+ if not allow_rejected and entry.event.type == EventTypes.Redaction:
+ if entry.event.redacts is None:
+ # A redacted redaction doesn't have a `redacts` key, in
+ # which case let's just withhold the event.
+ #
+ # Note: Most of the time if the redaction has been
+ # redacted we still have the un-redacted event in the DB
+ # and so we'll still see the `redacts` key. However, this
+ # isn't always true e.g. if we have censored the event.
+ logger.debug(
+ "Withholding redaction event %s as we don't have redacts key",
+ event_id,
+ )
+ continue
+
+ redacted_event_id = entry.event.redacts
+ event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+ original_event_entry = event_map.get(redacted_event_id)
+ if not original_event_entry:
+ # we don't have the redacted event (or it was rejected).
+ #
+ # We assume that the redaction isn't authorized for now; if the
+ # redacted event later turns up, the redaction will be re-checked,
+ # and if it is found valid, the original will get redacted before it
+ # is served to the client.
+ logger.debug(
+ "Withholding redaction event %s since we don't (yet) have the "
+ "original %s",
+ event_id,
+ redacted_event_id,
+ )
+ continue
+
+ original_event = original_event_entry.event
+ if original_event.type == EventTypes.Create:
+ # we never serve redactions of Creates to clients.
+ logger.info(
+ "Withholding redaction %s of create event %s",
+ event_id,
+ redacted_event_id,
+ )
+ continue
+
+ if original_event.room_id != entry.event.room_id:
+ logger.info(
+ "Withholding redaction %s of event %s from a different room",
+ event_id,
+ redacted_event_id,
+ )
+ continue
+
+ if entry.event.internal_metadata.need_to_check_redaction():
+ original_domain = get_domain_from_id(original_event.sender)
+ redaction_domain = get_domain_from_id(entry.event.sender)
+ if original_domain != redaction_domain:
+ # the senders don't match, so this is forbidden
+ logger.info(
+ "Withholding redaction %s whose sender domain %s doesn't "
+ "match that of redacted event %s %s",
+ event_id,
+ redaction_domain,
+ redacted_event_id,
+ original_domain,
+ )
+ continue
+
+ # Update the cache to save doing the checks again.
+ entry.event.internal_metadata.recheck_redaction = False
+
+ event = entry.event
+
+ if entry.redacted_event:
+ if redact_behaviour == EventRedactBehaviour.BLOCK:
+ # Skip this event
+ continue
+ elif redact_behaviour == EventRedactBehaviour.REDACT:
+ event = entry.redacted_event
+
+ events.append(event)
+
+ if get_prev_content:
+ if "replaces_state" in event.unsigned:
+ prev = yield self.get_event(
+ event.unsigned["replaces_state"],
+ get_prev_content=False,
+ allow_none=True,
+ )
+ if prev:
+ event.unsigned = dict(event.unsigned)
+ event.unsigned["prev_content"] = prev.content
+ event.unsigned["prev_sender"] = prev.sender
+
+ return events
+
+ @defer.inlineCallbacks
+ def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ """Fetch a bunch of events from the cache or the database.
+
+ If events are pulled from the database, they will be cached for future lookups.
+
+ Unknown events are omitted from the response.
+
+ Args:
+
+ event_ids (Iterable[str]): The event_ids of the events to fetch
+
+ allow_rejected (bool): Whether to include rejected events. If False,
+ rejected events are omitted from the response.
+
+ Returns:
+ Deferred[Dict[str, _EventCacheEntry]]:
+ map from event id to result
+ """
+ event_entry_map = self._get_events_from_cache(
+ event_ids, allow_rejected=allow_rejected
+ )
+
+ missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+ if missing_events_ids:
+ log_ctx = LoggingContext.current_context()
+ log_ctx.record_event_fetch(len(missing_events_ids))
+
+ # Note that _get_events_from_db is also responsible for turning db rows
+ # into FrozenEvents (via _get_event_from_row), which involves seeing if
+ # the events have been redacted, and if so pulling the redaction event out
+ # of the database to check it.
+ #
+ missing_events = yield self._get_events_from_db(
+ missing_events_ids, allow_rejected=allow_rejected
+ )
+
+ event_entry_map.update(missing_events)
+
+ return event_entry_map
+
+ def _invalidate_get_event_cache(self, event_id):
+ self._get_event_cache.invalidate((event_id,))
+
+ def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+ """Fetch events from the caches
+
+ Args:
+ events (Iterable[str]): list of event_ids to fetch
+ allow_rejected (bool): Whether to return events that were rejected
+ update_metrics (bool): Whether to update the cache hit ratio metrics
+
+ Returns:
+ dict of event_id -> _EventCacheEntry for each event_id in cache. If
+ allow_rejected is `False`, rejected events will still have an entry,
+ but it will be `None`.
+ """
+ event_map = {}
+
+ for event_id in events:
+ ret = self._get_event_cache.get(
+ (event_id,), None, update_metrics=update_metrics
+ )
+ if not ret:
+ continue
+
+ if allow_rejected or not ret.event.rejected_reason:
+ event_map[event_id] = ret
+ else:
+ event_map[event_id] = None
+
+ return event_map
+
+ def _do_fetch(self, conn):
+ """Takes a database connection and waits for requests for events from
+ the _event_fetch_list queue.
+ """
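+ # Keep draining the shared queue; once it has stayed empty across
+ # several timed waits (or immediately, on a single-threaded engine),
+ # this fetcher retires itself by decrementing _event_fetch_ongoing.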
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ self._event_fetch_ongoing -= 1
+ return
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ self._fetch_event_list(conn, event_list)
+
+ def _fetch_event_list(self, conn, event_list):
+ """Handle a load of requests from the _event_fetch_list queue
+
+ Args:
+ conn (twisted.enterprise.adbapi.Connection): database connection
+
+ event_list (list[Tuple[list[str], Deferred]]):
+ The fetch requests. Each entry consists of a list of event
+ ids to be fetched, and a deferred to be completed once the
+ events have been fetched.
+
+ The deferreds are callbacked with a dictionary mapping from event id
+ to event row. Note that it may well contain additional events that
+ were not part of this request.
+ """
+ with Measure(self._clock, "_fetch_event_list"):
+ try:
+ events_to_fetch = {
+ event_id for events, _ in event_list for event_id in events
+ }
+
+ row_dict = self.db.new_transaction(
+ conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
+ )
+
+ # We only want to resolve deferreds from the main thread
+ def fire():
+ for _, d in event_list:
+ d.callback(row_dict)
+
+ with PreserveLoggingContext():
+ self.hs.get_reactor().callFromThread(fire)
+ except Exception as e:
+ logger.exception("do_fetch")
+
+ # We only want to resolve deferreds from the main thread
+ def fire(evs, exc):
+ for _, d in evs:
+ if not d.called:
+ with PreserveLoggingContext():
+ d.errback(exc)
+
+ with PreserveLoggingContext():
+ self.hs.get_reactor().callFromThread(fire, event_list, e)
+
+ @defer.inlineCallbacks
+ def _get_events_from_db(self, event_ids, allow_rejected=False):
+ """Fetch a bunch of events from the database.
+
+ Returned events will be added to the cache for future lookups.
+
+ Unknown events are omitted from the response.
+
+ Args:
+ event_ids (Iterable[str]): The event_ids of the events to fetch
+
+ allow_rejected (bool): Whether to include rejected events. If False,
+ rejected events are omitted from the response.
+
+ Returns:
+ Deferred[Dict[str, _EventCacheEntry]]:
+ map from event id to result. May return extra events which
+ weren't asked for.
+ """
+ fetched_events = {}
+ events_to_fetch = event_ids
+
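+ # Each pass of this loop fetches the requested events plus any
+ # redaction events they reference, until no new redactions turn up.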
+ while events_to_fetch:
+ row_map = yield self._enqueue_events(events_to_fetch)
+
+ # we need to recursively fetch any redactions of those events
+ redaction_ids = set()
+ for event_id in events_to_fetch:
+ row = row_map.get(event_id)
+ fetched_events[event_id] = row
+ if row:
+ redaction_ids.update(row["redactions"])
+
+ events_to_fetch = redaction_ids.difference(fetched_events.keys())
+ if events_to_fetch:
+ logger.debug("Also fetching redaction events %s", events_to_fetch)
+
+ # build a map from event_id to EventBase
+ event_map = {}
+ for event_id, row in fetched_events.items():
+ if not row:
+ continue
+ assert row["event_id"] == event_id
+
+ rejected_reason = row["rejected_reason"]
+
+ if not allow_rejected and rejected_reason:
+ continue
+
+ d = json.loads(row["json"])
+ internal_metadata = json.loads(row["internal_metadata"])
+
+ format_version = row["format_version"]
+ if format_version is None:
+ # This means that we stored the event before we had the concept
+ # of an event format version, so it must be a V1 event.
+ format_version = EventFormatVersions.V1
+
+ room_version_id = row["room_version_id"]
+
+ if not room_version_id:
+ # this should only happen for out-of-band membership events
+ if not internal_metadata.get("out_of_band_membership"):
+ logger.warning(
+ "Room %s for event %s is unknown", d["room_id"], event_id
+ )
+ continue
+
+ # take a wild stab at the room version based on the event format
+ if format_version == EventFormatVersions.V1:
+ room_version = RoomVersions.V1
+ elif format_version == EventFormatVersions.V2:
+ room_version = RoomVersions.V3
+ else:
+ room_version = RoomVersions.V5
+ else:
+ room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not room_version:
+ logger.error(
+ "Event %s in room %s has unknown room version %s",
+ event_id,
+ d["room_id"],
+ room_version_id,
+ )
+ continue
+
+ if room_version.event_format != format_version:
+ logger.error(
+ "Event %s in room %s with version %s has wrong format: "
+ "expected %s, was %s",
+ event_id,
+ d["room_id"],
+ room_version_id,
+ room_version.event_format,
+ format_version,
+ )
+ continue
+
+ original_ev = make_event_from_dict(
+ event_dict=d,
+ room_version=room_version,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
+
+ event_map[event_id] = original_ev
+
+ # finally, we can decide whether each one needs redacting, and build
+ # the cache entries.
+ result_map = {}
+ for event_id, original_ev in event_map.items():
+ redactions = fetched_events[event_id]["redactions"]
+ redacted_event = self._maybe_redact_event_row(
+ original_ev, redactions, event_map
+ )
+
+ cache_entry = _EventCacheEntry(
+ event=original_ev, redacted_event=redacted_event
+ )
+
+ self._get_event_cache.prefill((event_id,), cache_entry)
+ result_map[event_id] = cache_entry
+
+ return result_map
+
+ @defer.inlineCallbacks
+ def _enqueue_events(self, events):
+ """Fetches events from the database using the _event_fetch_list. This
+ allows batch and bulk fetching of events - it allows us to fetch events
+ without having to create a new transaction for each request for events.
+
+ Args:
+ events (Iterable[str]): events to be fetched.
+
+ Returns:
+ Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+ May contain events that weren't requested.
+ """
+
+ events_d = defer.Deferred()
+ with self._event_fetch_lock:
+ self._event_fetch_list.append((events, events_d))
+
+ self._event_fetch_lock.notify()
+
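+ # Only spin up another fetcher if we're under the thread cap;
+ # otherwise one of the existing fetchers will pick this request up.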
+ if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+ self._event_fetch_ongoing += 1
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ run_as_background_process(
+ "fetch_events", self.db.runWithConnection, self._do_fetch
+ )
+
+ logger.debug("Loading %d events: %s", len(events), events)
+ with PreserveLoggingContext():
+ row_map = yield events_d
+ logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
+
+ return row_map
+
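A minimal, synchronous analogue of the queueing pattern used by _enqueue_events: callers append their request to a shared list under a lock, and at most a fixed number of background workers drain the list in batches, resolving every waiting caller with the combined result. This is only an illustrative sketch using threads and a made-up BatchFetcher class; the real code uses Twisted Deferreds, _event_fetch_lock and run_as_background_process.

    import threading

    class BatchFetcher:
        MAX_WORKERS = 3  # plays the role of EVENT_QUEUE_THREADS

        def __init__(self, fetch_rows):
            self._fetch_rows = fetch_rows   # callable: list[str] -> dict[str, dict]
            self._lock = threading.Lock()
            self._pending = []              # list of (event_ids, result_dict, done_event)
            self._ongoing = 0

        def enqueue(self, event_ids):
            done = threading.Event()
            result = {}
            with self._lock:
                self._pending.append((event_ids, result, done))
                # only start another worker if we are under the limit
                if self._ongoing < self.MAX_WORKERS:
                    self._ongoing += 1
                    threading.Thread(target=self._worker, daemon=True).start()
            done.wait()
            return result  # may also contain rows requested by other callers

        def _worker(self):
            while True:
                with self._lock:
                    if not self._pending:
                        self._ongoing -= 1
                        return
                    batch, self._pending = self._pending, []
                all_ids = [eid for ids, _, _ in batch for eid in ids]
                rows = self._fetch_rows(all_ids)
                for _, result, done in batch:
                    result.update(rows)
                    done.set()

    fetcher = BatchFetcher(lambda ids: {i: {"event_id": i} for i in ids})
    print(fetcher.enqueue(["$a:example.com", "$b:example.com"]))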
+ def _fetch_event_rows(self, txn, event_ids):
+ """Fetch event rows from the database
+
+ Events which are not found are omitted from the result.
+
+ The returned per-event dicts contain the following keys:
+
+ * event_id (str)
+
+ * json (str): json-encoded event structure
+
+ * internal_metadata (str): json-encoded internal metadata dict
+
+ * format_version (int|None): The format of the event. Hopefully one
+ of EventFormatVersions. 'None' means the event predates
+ EventFormatVersions (so the event is format V1).
+
+ * room_version_id (str|None): The version of the room which contains the event.
+ Hopefully one of RoomVersions.
+
+ Due to historical reasons, there may be a few events in the database which
+ do not have an associated room; in this case None will be returned here.
+
+ * rejected_reason (str|None): if the event was rejected, the reason
+ why.
+
+ * redactions (List[str]): a list of event-ids which (claim to) redact
+ this event.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection):
+ event_ids (Iterable[str]): event IDs to fetch
+
+ Returns:
+ Dict[str, Dict]: a map from event id to event info.
+ """
+ event_dict = {}
+ for evs in batch_iter(event_ids, 200):
+ sql = """\
+ SELECT
+ e.event_id,
+ e.internal_metadata,
+ e.json,
+ e.format_version,
+ r.room_version,
+ rej.reason
+ FROM event_json as e
+ LEFT JOIN rooms r USING (room_id)
+ LEFT JOIN rejections as rej USING (event_id)
+ WHERE """
+
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "e.event_id", evs
+ )
+
+ txn.execute(sql + clause, args)
+
+ for row in txn:
+ event_id = row[0]
+ event_dict[event_id] = {
+ "event_id": event_id,
+ "internal_metadata": row[1],
+ "json": row[2],
+ "format_version": row[3],
+ "room_version_id": row[4],
+ "rejected_reason": row[5],
+ "redactions": [],
+ }
+
+ # check for redactions
+ redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
+
+ clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs)
+
+ txn.execute(redactions_sql + clause, args)
+
+ for (redacter, redacted) in txn:
+ d = event_dict.get(redacted)
+ if d:
+ d["redactions"].append(redacter)
+
+ return event_dict
+
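The per-chunk SELECT above relies on Synapse's batch_iter and make_in_list_sql_clause helpers. The same pattern can be approximated with plain sqlite3 and hand-built placeholders; the table contents below are purely illustrative, and note that missing events are simply absent from the result, as the docstring says.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE event_json (event_id TEXT PRIMARY KEY, json TEXT)")
    conn.executemany(
        "INSERT INTO event_json VALUES (?, ?)",
        [("$a:example.com", "{}"), ("$b:example.com", "{}")],
    )

    def fetch_event_rows(conn, event_ids, chunk_size=200):
        results = {}
        ids = list(event_ids)
        for i in range(0, len(ids), chunk_size):
            chunk = ids[i : i + chunk_size]
            # one "?" per id -- roughly what make_in_list_sql_clause produces on SQLite
            placeholders = ", ".join("?" for _ in chunk)
            sql = "SELECT event_id, json FROM event_json WHERE event_id IN (%s)" % placeholders
            for event_id, json_blob in conn.execute(sql, chunk):
                results[event_id] = {"event_id": event_id, "json": json_blob}
        return results

    print(fetch_event_rows(conn, ["$a:example.com", "$missing:example.com"]))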
+ def _maybe_redact_event_row(self, original_ev, redactions, event_map):
+ """Given an event object and a list of possible redacting event ids,
+ determine whether to honour any of those redactions and if so return a redacted
+ event.
+
+ Args:
+ original_ev (EventBase):
+ redactions (iterable[str]): list of event ids of potential redaction events
+ event_map (dict[str, EventBase]): other events which have been fetched, in
+ which we can look up the redaction events. Map from event id to event.
+
+ Returns:
+ EventBase|None: if the event should be redacted, a pruned
+ event object. Otherwise, None.
+ """
+ if original_ev.type == "m.room.create":
+ # we choose to ignore redactions of m.room.create events.
+ return None
+
+ for redaction_id in redactions:
+ redaction_event = event_map.get(redaction_id)
+ if not redaction_event or redaction_event.rejected_reason:
+ # we don't have the redaction event, or the redaction event was not
+ # authorized.
+ logger.debug(
+ "%s was redacted by %s but redaction not found/authed",
+ original_ev.event_id,
+ redaction_id,
+ )
+ continue
+
+ if redaction_event.room_id != original_ev.room_id:
+ logger.debug(
+ "%s was redacted by %s but redaction was in a different room!",
+ original_ev.event_id,
+ redaction_id,
+ )
+ continue
+
+ # Starting in room version v3, some redactions need to be
+ # rechecked if we didn't have the redacted event at the
+ # time, so we recheck on read instead.
+ if redaction_event.internal_metadata.need_to_check_redaction():
+ expected_domain = get_domain_from_id(original_ev.sender)
+ if get_domain_from_id(redaction_event.sender) == expected_domain:
+ # This redaction event is allowed. Mark as not needing a recheck.
+ redaction_event.internal_metadata.recheck_redaction = False
+ else:
+ # Senders don't match, so the event isn't actually redacted
+ logger.debug(
+ "%s was redacted by %s but the senders don't match",
+ original_ev.event_id,
+ redaction_id,
+ )
+ continue
+
+ logger.debug("Redacting %s due to %s", original_ev.event_id, redaction_id)
+
+ # we found a good redaction event. Redact!
+ redacted_event = prune_event(original_ev)
+ redacted_event.unsigned["redacted_by"] = redaction_id
+
+ # It's fine to add the event directly, since get_pdu_json
+ # will serialise this field correctly
+ redacted_event.unsigned["redacted_because"] = redaction_event
+
+ return redacted_event
+
+ # no valid redaction found for this event
+ return None
+
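The sender-domain recheck above boils down to comparing the server parts of the two user ids; for instance (a simplified stand-in for the real get_domain_from_id helper, which also validates the id):

    def get_domain_from_id(user_id):
        # "@alice:example.com" -> "example.com"
        return user_id.split(":", 1)[1]

    assert get_domain_from_id("@alice:example.com") == get_domain_from_id("@bob:example.com")
    assert get_domain_from_id("@alice:example.com") != get_domain_from_id("@mallory:evil.example")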
+ @defer.inlineCallbacks
+ def have_events_in_timeline(self, event_ids):
+ """Given a list of event ids, check if we have already processed and
+ stored them as non outliers.
+ """
+ rows = yield self.db.simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
+ )
+
+ return {r["event_id"] for r in rows}
+
+ @defer.inlineCallbacks
+ def have_seen_events(self, event_ids):
+ """Given a list of event ids, check if we have already processed them.
+
+ Args:
+ event_ids (iterable[str]):
+
+ Returns:
+ Deferred[set[str]]: The events we have already seen.
+ """
+ results = set()
+
+ def have_seen_events_txn(txn, chunk):
+ sql = "SELECT event_id FROM events as e WHERE "
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "e.event_id", chunk
+ )
+ txn.execute(sql + clause, args)
+ for (event_id,) in txn:
+ results.add(event_id)
+
+ # break the input up into chunks of 100
+ input_iterator = iter(event_ids)
+ for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
+ yield self.db.runInteraction(
+ "have_seen_events", have_seen_events_txn, chunk
+ )
+ return results
+
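The chunking idiom in have_seen_events, shown in isolation: the two-argument form of iter() keeps calling the lambda until it returns the sentinel (an empty list), which walks the input in fixed-size chunks without materialising the whole thing at once.

    import itertools

    def chunks(iterable, size):
        it = iter(iterable)
        return iter(lambda: list(itertools.islice(it, size)), [])

    print(list(chunks(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]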
+ def _get_total_state_event_counts_txn(self, txn, room_id):
+ """
+ See get_total_state_event_counts.
+ """
+ # We join against the events table as that has an index on room_id
+ sql = """
+ SELECT COUNT(*) FROM state_events
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id=?
+ """
+ txn.execute(sql, (room_id,))
+ row = txn.fetchone()
+ return row[0] if row else 0
+
+ def get_total_state_event_counts(self, room_id):
+ """
+ Gets the total number of state events in a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.db.runInteraction(
+ "get_total_state_event_counts",
+ self._get_total_state_event_counts_txn,
+ room_id,
+ )
+
+ def _get_current_state_event_counts_txn(self, txn, room_id):
+ """
+ See get_current_state_event_counts.
+ """
+ sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?"
+ txn.execute(sql, (room_id,))
+ row = txn.fetchone()
+ return row[0] if row else 0
+
+ def get_current_state_event_counts(self, room_id):
+ """
+ Gets the current number of state events in a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.db.runInteraction(
+ "get_current_state_event_counts",
+ self._get_current_state_event_counts_txn,
+ room_id,
+ )
+
+ @defer.inlineCallbacks
+ def get_room_complexity(self, room_id):
+ """
+ Get a rough approximation of the complexity of the room. This is used by
+ remote servers to decide whether they wish to join the room or not.
+ Higher complexity value indicates that being in the room will consume
+ more resources.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[dict[str, int]]: map from complexity version to complexity score.
+ """
+ state_events = yield self.get_current_state_event_counts(room_id)
+
+ # Call this one "v1", so we can introduce new ones as we want to develop
+ # it.
+ complexity_v1 = round(state_events / 500, 2)
+
+ return {"v1": complexity_v1}
diff --git a/synapse/storage/filtering.py b/synapse/storage/data_stores/main/filtering.py
index b195dc66a0..342d6622a4 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/data_stores/main/filtering.py
@@ -15,13 +15,10 @@
from canonicaljson import encode_canonical_json
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from ._base import SQLBaseStore, db_to_json
-
class FilteringStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
@@ -33,7 +30,7 @@ class FilteringStore(SQLBaseStore):
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
- def_json = yield self._simple_select_one_onecol(
+ def_json = yield self.db.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
@@ -41,7 +38,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(db_to_json(def_json))
+ return db_to_json(def_json)
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
@@ -53,12 +50,12 @@ class FilteringStore(SQLBaseStore):
"SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?"
)
- txn.execute(sql, (user_localpart, def_json))
+ txn.execute(sql, (user_localpart, bytearray(def_json)))
filter_id_response = txn.fetchone()
if filter_id_response is not None:
return filter_id_response[0]
- sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
+ sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
if max_id is None:
@@ -70,8 +67,8 @@ class FilteringStore(SQLBaseStore):
"INSERT INTO user_filters (user_id, filter_id, filter_json)"
"VALUES(?, ?, ?)"
)
- txn.execute(sql, (user_localpart, filter_id, def_json))
+ txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
return filter_id
- return self.runInteraction("add_user_filter", _do_txn)
+ return self.db.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/group_server.py b/synapse/storage/data_stores/main/group_server.py
index dce6a43ac1..0963e6c250 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -19,8 +19,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
@@ -28,23 +27,9 @@ _DEFAULT_CATEGORY_ID = ""
_DEFAULT_ROLE_ID = ""
-class GroupServerStore(SQLBaseStore):
- def set_group_join_policy(self, group_id, join_policy):
- """Set the join policy of a group.
-
- join_policy can be one of:
- * "invite"
- * "open"
- """
- return self._simple_update_one(
- table="groups",
- keyvalues={"group_id": group_id},
- updatevalues={"join_policy": join_policy},
- desc="set_group_join_policy",
- )
-
+class GroupServerWorkerStore(SQLBaseStore):
def get_group(self, group_id):
- return self._simple_select_one(
+ return self.db.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -66,7 +51,7 @@ class GroupServerStore(SQLBaseStore):
if not include_private:
keyvalues["is_public"] = True
- return self._simple_select_list(
+ return self.db.simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"),
@@ -76,7 +61,7 @@ class GroupServerStore(SQLBaseStore):
def get_invited_users_in_group(self, group_id):
# TODO: Pagination
- return self._simple_select_onecol(
+ return self.db.simple_select_onecol(
table="group_invites",
keyvalues={"group_id": group_id},
retcol="user_id",
@@ -90,7 +75,7 @@ class GroupServerStore(SQLBaseStore):
if not include_private:
keyvalues["is_public"] = True
- return self._simple_select_list(
+ return self.db.simple_select_list(
table="group_rooms",
keyvalues=keyvalues,
retcols=("room_id", "is_public"),
@@ -154,10 +139,372 @@ class GroupServerStore(SQLBaseStore):
return rooms, categories
- return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn)
+ return self.db.runInteraction(
+ "get_rooms_for_summary", _get_rooms_for_summary_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_group_categories(self, group_id):
+ rows = yield self.db.simple_select_list(
+ table="group_room_categories",
+ keyvalues={"group_id": group_id},
+ retcols=("category_id", "is_public", "profile"),
+ desc="get_group_categories",
+ )
+
+ return {
+ row["category_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
+ }
+
+ @defer.inlineCallbacks
+ def get_group_category(self, group_id, category_id):
+ category = yield self.db.simple_select_one(
+ table="group_room_categories",
+ keyvalues={"group_id": group_id, "category_id": category_id},
+ retcols=("is_public", "profile"),
+ desc="get_group_category",
+ )
+
+ category["profile"] = json.loads(category["profile"])
+
+ return category
+
+ @defer.inlineCallbacks
+ def get_group_roles(self, group_id):
+ rows = yield self.db.simple_select_list(
+ table="group_roles",
+ keyvalues={"group_id": group_id},
+ retcols=("role_id", "is_public", "profile"),
+ desc="get_group_roles",
+ )
+
+ return {
+ row["role_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
+ }
+
+ @defer.inlineCallbacks
+ def get_group_role(self, group_id, role_id):
+ role = yield self.db.simple_select_one(
+ table="group_roles",
+ keyvalues={"group_id": group_id, "role_id": role_id},
+ retcols=("is_public", "profile"),
+ desc="get_group_role",
+ )
+
+ role["profile"] = json.loads(role["profile"])
+
+ return role
+
+ def get_local_groups_for_room(self, room_id):
+ """Get all of the local group that contain a given room
+ Args:
+ room_id (str): The ID of a room
+ Returns:
+ Deferred[list[str]]: the ids of the local groups that
+ contain this room
+ """
+ return self.db.simple_select_onecol(
+ table="group_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="group_id",
+ desc="get_local_groups_for_room",
+ )
+
+ def get_users_for_summary_by_role(self, group_id, include_private=False):
+ """Get the users and roles that should be included in a summary request
+
+ Returns ([users], [roles])
+ """
+
+ def _get_users_for_summary_txn(txn):
+ keyvalues = {"group_id": group_id}
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ sql = """
+ SELECT user_id, is_public, role_id, user_order
+ FROM group_summary_users
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ users = [
+ {
+ "user_id": row[0],
+ "is_public": row[1],
+ "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
+ "order": row[3],
+ }
+ for row in txn
+ ]
+
+ sql = """
+ SELECT role_id, is_public, profile, role_order
+ FROM group_summary_roles
+ INNER JOIN group_roles USING (group_id, role_id)
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ roles = {
+ row[0]: {
+ "is_public": row[1],
+ "profile": json.loads(row[2]),
+ "order": row[3],
+ }
+ for row in txn
+ }
+
+ return users, roles
+
+ return self.db.runInteraction(
+ "get_users_for_summary_by_role", _get_users_for_summary_txn
+ )
+
+ def is_user_in_group(self, user_id, group_id):
+ return self.db.simple_select_one_onecol(
+ table="group_users",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="is_user_in_group",
+ ).addCallback(lambda r: bool(r))
+
+ def is_user_admin_in_group(self, group_id, user_id):
+ return self.db.simple_select_one_onecol(
+ table="group_users",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcol="is_admin",
+ allow_none=True,
+ desc="is_user_admin_in_group",
+ )
+
+ def is_user_invited_to_local_group(self, group_id, user_id):
+ """Has the group server invited a user?
+ """
+ return self.db.simple_select_one_onecol(
+ table="group_invites",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcol="user_id",
+ desc="is_user_invited_to_local_group",
+ allow_none=True,
+ )
+
+ def get_users_membership_info_in_group(self, group_id, user_id):
+ """Get a dict describing the membership of a user in a group.
+
+ Example if joined:
+
+ {
+ "membership": "join",
+ "is_public": True,
+ "is_privileged": False,
+ }
+
+ Returns an empty dict if the user's membership is not join/invite/etc.
+ """
+
+ def _get_users_membership_in_group_txn(txn):
+ row = self.db.simple_select_one_txn(
+ txn,
+ table="group_users",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcols=("is_admin", "is_public"),
+ allow_none=True,
+ )
+
+ if row:
+ return {
+ "membership": "join",
+ "is_public": row["is_public"],
+ "is_privileged": row["is_admin"],
+ }
+
+ row = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="group_invites",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ )
+
+ if row:
+ return {"membership": "invite"}
+
+ return {}
+
+ return self.db.runInteraction(
+ "get_users_membership_info_in_group", _get_users_membership_in_group_txn
+ )
+
+ def get_publicised_groups_for_user(self, user_id):
+ """Get all groups a user is publicising
+ """
+ return self.db.simple_select_onecol(
+ table="local_group_membership",
+ keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
+ retcol="group_id",
+ desc="get_publicised_groups_for_user",
+ )
+
+ def get_attestations_need_renewals(self, valid_until_ms):
+ """Get all attestations that need to be renewed until givent time
+ """
+
+ def _get_attestations_need_renewals_txn(txn):
+ sql = """
+ SELECT group_id, user_id FROM group_attestations_renewals
+ WHERE valid_until_ms <= ?
+ """
+ txn.execute(sql, (valid_until_ms,))
+ return self.db.cursor_to_dict(txn)
+
+ return self.db.runInteraction(
+ "get_attestations_need_renewals", _get_attestations_need_renewals_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_remote_attestation(self, group_id, user_id):
+ """Get the attestation that proves the remote agrees that the user is
+ in the group.
+ """
+ row = yield self.db.simple_select_one(
+ table="group_attestations_remote",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcols=("valid_until_ms", "attestation_json"),
+ desc="get_remote_attestation",
+ allow_none=True,
+ )
+
+ now = int(self._clock.time_msec())
+ if row and now < row["valid_until_ms"]:
+ return json.loads(row["attestation_json"])
+
+ return None
+
+ def get_joined_groups(self, user_id):
+ return self.db.simple_select_onecol(
+ table="local_group_membership",
+ keyvalues={"user_id": user_id, "membership": "join"},
+ retcol="group_id",
+ desc="get_joined_groups",
+ )
+
+ def get_all_groups_for_user(self, user_id, now_token):
+ def _get_all_groups_for_user_txn(txn):
+ sql = """
+ SELECT group_id, type, membership, u.content
+ FROM local_group_updates AS u
+ INNER JOIN local_group_membership USING (group_id, user_id)
+ WHERE user_id = ? AND membership != 'leave'
+ AND stream_id <= ?
+ """
+ txn.execute(sql, (user_id, now_token))
+ return [
+ {
+ "group_id": row[0],
+ "type": row[1],
+ "membership": row[2],
+ "content": json.loads(row[3]),
+ }
+ for row in txn
+ ]
+
+ return self.db.runInteraction(
+ "get_all_groups_for_user", _get_all_groups_for_user_txn
+ )
+
+ def get_groups_changes_for_user(self, user_id, from_token, to_token):
+ from_token = int(from_token)
+ has_changed = self._group_updates_stream_cache.has_entity_changed(
+ user_id, from_token
+ )
+ if not has_changed:
+ return defer.succeed([])
+
+ def _get_groups_changes_for_user_txn(txn):
+ sql = """
+ SELECT group_id, membership, type, u.content
+ FROM local_group_updates AS u
+ INNER JOIN local_group_membership USING (group_id, user_id)
+ WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
+ """
+ txn.execute(sql, (user_id, from_token, to_token))
+ return [
+ {
+ "group_id": group_id,
+ "membership": membership,
+ "type": gtype,
+ "content": json.loads(content_json),
+ }
+ for group_id, membership, gtype, content_json in txn
+ ]
+
+ return self.db.runInteraction(
+ "get_groups_changes_for_user", _get_groups_changes_for_user_txn
+ )
+
+ def get_all_groups_changes(self, from_token, to_token, limit):
+ from_token = int(from_token)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(
+ from_token
+ )
+ if not has_changed:
+ return defer.succeed([])
+
+ def _get_all_groups_changes_txn(txn):
+ sql = """
+ SELECT stream_id, group_id, user_id, type, content
+ FROM local_group_updates
+ WHERE ? < stream_id AND stream_id <= ?
+ LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, limit))
+ return [
+ (stream_id, group_id, user_id, gtype, json.loads(content_json))
+ for stream_id, group_id, user_id, gtype, content_json in txn
+ ]
+
+ return self.db.runInteraction(
+ "get_all_groups_changes", _get_all_groups_changes_txn
+ )
+
+
+class GroupServerStore(GroupServerWorkerStore):
+ def set_group_join_policy(self, group_id, join_policy):
+ """Set the join policy of a group.
+
+ join_policy can be one of:
+ * "invite"
+ * "open"
+ """
+ return self.db.simple_update_one(
+ table="groups",
+ keyvalues={"group_id": group_id},
+ updatevalues={"join_policy": join_policy},
+ desc="set_group_join_policy",
+ )
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
- return self.runInteraction(
+ return self.db.runInteraction(
"add_room_to_summary",
self._add_room_to_summary_txn,
group_id,
@@ -181,7 +528,7 @@ class GroupServerStore(SQLBaseStore):
an order of 1 will put the room first. Otherwise, the room gets
added to the end.
"""
- room_in_group = self._simple_select_one_onecol_txn(
+ room_in_group = self.db.simple_select_one_onecol_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
@@ -194,7 +541,7 @@ class GroupServerStore(SQLBaseStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
else:
- cat_exists = self._simple_select_one_onecol_txn(
+ cat_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -205,7 +552,7 @@ class GroupServerStore(SQLBaseStore):
raise SynapseError(400, "Category doesn't exist")
# TODO: Check category is part of summary already
- cat_exists = self._simple_select_one_onecol_txn(
+ cat_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_summary_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -225,7 +572,7 @@ class GroupServerStore(SQLBaseStore):
(group_id, category_id, group_id, category_id),
)
- existing = self._simple_select_one_txn(
+ existing = self.db.simple_select_one_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -250,7 +597,7 @@ class GroupServerStore(SQLBaseStore):
WHERE group_id = ? AND category_id = ?
"""
txn.execute(sql, (group_id, category_id))
- order, = txn.fetchone()
+ (order,) = txn.fetchone()
if existing:
to_update = {}
@@ -258,7 +605,7 @@ class GroupServerStore(SQLBaseStore):
to_update["room_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self._simple_update_txn(
+ self.db.simple_update_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -272,7 +619,7 @@ class GroupServerStore(SQLBaseStore):
if is_public is None:
is_public = True
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_summary_rooms",
values={
@@ -288,7 +635,7 @@ class GroupServerStore(SQLBaseStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
- return self._simple_delete(
+ return self.db.simple_delete(
table="group_summary_rooms",
keyvalues={
"group_id": group_id,
@@ -298,38 +645,6 @@ class GroupServerStore(SQLBaseStore):
desc="remove_room_from_summary",
)
- @defer.inlineCallbacks
- def get_group_categories(self, group_id):
- rows = yield self._simple_select_list(
- table="group_room_categories",
- keyvalues={"group_id": group_id},
- retcols=("category_id", "is_public", "profile"),
- desc="get_group_categories",
- )
-
- defer.returnValue(
- {
- row["category_id"]: {
- "is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
- }
- for row in rows
- }
- )
-
- @defer.inlineCallbacks
- def get_group_category(self, group_id, category_id):
- category = yield self._simple_select_one(
- table="group_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- retcols=("is_public", "profile"),
- desc="get_group_category",
- )
-
- category["profile"] = json.loads(category["profile"])
-
- defer.returnValue(category)
-
def upsert_group_category(self, group_id, category_id, profile, is_public):
"""Add/update room category for group
"""
@@ -346,7 +661,7 @@ class GroupServerStore(SQLBaseStore):
else:
update_values["is_public"] = is_public
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
@@ -355,44 +670,12 @@ class GroupServerStore(SQLBaseStore):
)
def remove_group_category(self, group_id, category_id):
- return self._simple_delete(
+ return self.db.simple_delete(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
desc="remove_group_category",
)
- @defer.inlineCallbacks
- def get_group_roles(self, group_id):
- rows = yield self._simple_select_list(
- table="group_roles",
- keyvalues={"group_id": group_id},
- retcols=("role_id", "is_public", "profile"),
- desc="get_group_roles",
- )
-
- defer.returnValue(
- {
- row["role_id"]: {
- "is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
- }
- for row in rows
- }
- )
-
- @defer.inlineCallbacks
- def get_group_role(self, group_id, role_id):
- role = yield self._simple_select_one(
- table="group_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- retcols=("is_public", "profile"),
- desc="get_group_role",
- )
-
- role["profile"] = json.loads(role["profile"])
-
- defer.returnValue(role)
-
def upsert_group_role(self, group_id, role_id, profile, is_public):
"""Add/remove user role
"""
@@ -409,7 +692,7 @@ class GroupServerStore(SQLBaseStore):
else:
update_values["is_public"] = is_public
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
@@ -418,14 +701,14 @@ class GroupServerStore(SQLBaseStore):
)
def remove_group_role(self, group_id, role_id):
- return self._simple_delete(
+ return self.db.simple_delete(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
desc="remove_group_role",
)
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
- return self.runInteraction(
+ return self.db.runInteraction(
"add_user_to_summary",
self._add_user_to_summary_txn,
group_id,
@@ -449,7 +732,7 @@ class GroupServerStore(SQLBaseStore):
an order of 1 will put the user first. Otherwise, the user gets
added to the end.
"""
- user_in_group = self._simple_select_one_onecol_txn(
+ user_in_group = self.db.simple_select_one_onecol_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -462,7 +745,7 @@ class GroupServerStore(SQLBaseStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
else:
- role_exists = self._simple_select_one_onecol_txn(
+ role_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -473,7 +756,7 @@ class GroupServerStore(SQLBaseStore):
raise SynapseError(400, "Role doesn't exist")
# TODO: Check role is part of the summary already
- role_exists = self._simple_select_one_onecol_txn(
+ role_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_summary_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -493,7 +776,7 @@ class GroupServerStore(SQLBaseStore):
(group_id, role_id, group_id, role_id),
)
- existing = self._simple_select_one_txn(
+ existing = self.db.simple_select_one_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
@@ -514,7 +797,7 @@ class GroupServerStore(SQLBaseStore):
WHERE group_id = ? AND role_id = ?
"""
txn.execute(sql, (group_id, role_id))
- order, = txn.fetchone()
+ (order,) = txn.fetchone()
if existing:
to_update = {}
@@ -522,7 +805,7 @@ class GroupServerStore(SQLBaseStore):
to_update["user_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self._simple_update_txn(
+ self.db.simple_update_txn(
txn,
table="group_summary_users",
keyvalues={
@@ -536,7 +819,7 @@ class GroupServerStore(SQLBaseStore):
if is_public is None:
is_public = True
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_summary_users",
values={
@@ -552,158 +835,21 @@ class GroupServerStore(SQLBaseStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
- return self._simple_delete(
+ return self.db.simple_delete(
table="group_summary_users",
keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
desc="remove_user_from_summary",
)
- def get_users_for_summary_by_role(self, group_id, include_private=False):
- """Get the users and roles that should be included in a summary request
-
- Returns ([users], [roles])
- """
-
- def _get_users_for_summary_txn(txn):
- keyvalues = {"group_id": group_id}
- if not include_private:
- keyvalues["is_public"] = True
-
- sql = """
- SELECT user_id, is_public, role_id, user_order
- FROM group_summary_users
- WHERE group_id = ?
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
- else:
- txn.execute(sql, (group_id,))
-
- users = [
- {
- "user_id": row[0],
- "is_public": row[1],
- "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
- "order": row[3],
- }
- for row in txn
- ]
-
- sql = """
- SELECT role_id, is_public, profile, role_order
- FROM group_summary_roles
- INNER JOIN group_roles USING (group_id, role_id)
- WHERE group_id = ?
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
- else:
- txn.execute(sql, (group_id,))
-
- roles = {
- row[0]: {
- "is_public": row[1],
- "profile": json.loads(row[2]),
- "order": row[3],
- }
- for row in txn
- }
-
- return users, roles
-
- return self.runInteraction(
- "get_users_for_summary_by_role", _get_users_for_summary_txn
- )
-
- def is_user_in_group(self, user_id, group_id):
- return self._simple_select_one_onecol(
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="is_user_in_group",
- ).addCallback(lambda r: bool(r))
-
- def is_user_admin_in_group(self, group_id, user_id):
- return self._simple_select_one_onecol(
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="is_admin",
- allow_none=True,
- desc="is_user_admin_in_group",
- )
-
def add_group_invite(self, group_id, user_id):
"""Record that the group server has invited a user
"""
- return self._simple_insert(
+ return self.db.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
)
- def is_user_invited_to_local_group(self, group_id, user_id):
- """Has the group server invited a user?
- """
- return self._simple_select_one_onecol(
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- desc="is_user_invited_to_local_group",
- allow_none=True,
- )
-
- def get_users_membership_info_in_group(self, group_id, user_id):
- """Get a dict describing the membership of a user in a group.
-
- Example if joined:
-
- {
- "membership": "join",
- "is_public": True,
- "is_privileged": False,
- }
-
- Returns an empty dict if the user is not join/invite/etc
- """
-
- def _get_users_membership_in_group_txn(txn):
- row = self._simple_select_one_txn(
- txn,
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcols=("is_admin", "is_public"),
- allow_none=True,
- )
-
- if row:
- return {
- "membership": "join",
- "is_public": row["is_public"],
- "is_privileged": row["is_admin"],
- }
-
- row = self._simple_select_one_onecol_txn(
- txn,
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- allow_none=True,
- )
-
- if row:
- return {"membership": "invite"}
-
- return {}
-
- return self.runInteraction(
- "get_users_membership_info_in_group", _get_users_membership_in_group_txn
- )
-
def add_user_to_group(
self,
group_id,
@@ -728,7 +874,7 @@ class GroupServerStore(SQLBaseStore):
"""
def _add_user_to_group_txn(txn):
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_users",
values={
@@ -739,14 +885,14 @@ class GroupServerStore(SQLBaseStore):
},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
if local_attestation:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -756,7 +902,7 @@ class GroupServerStore(SQLBaseStore):
},
)
if remote_attestation:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
@@ -767,49 +913,49 @@ class GroupServerStore(SQLBaseStore):
},
)
- return self.runInteraction("add_user_to_group", _add_user_to_group_txn)
+ return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn)
def remove_user_from_group(self, group_id, user_id):
def _remove_user_from_group_txn(txn):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn
)
def add_room_to_group(self, group_id, room_id, is_public):
- return self._simple_insert(
+ return self.db.simple_insert(
table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
)
def update_room_in_group_visibility(self, group_id, room_id, is_public):
- return self._simple_update(
+ return self.db.simple_update(
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public},
@@ -818,36 +964,26 @@ class GroupServerStore(SQLBaseStore):
def remove_room_from_group(self, group_id, room_id):
def _remove_room_from_group_txn(txn):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_summary_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn
)
- def get_publicised_groups_for_user(self, user_id):
- """Get all groups a user is publicising
- """
- return self._simple_select_onecol(
- table="local_group_membership",
- keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
- retcol="group_id",
- desc="get_publicised_groups_for_user",
- )
-
def update_group_publicity(self, group_id, user_id, publicise):
"""Update whether the user is publicising their membership of the group
"""
- return self._simple_update_one(
+ return self.db.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise},
@@ -883,12 +1019,12 @@ class GroupServerStore(SQLBaseStore):
def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert?
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="local_group_membership",
values={
@@ -901,7 +1037,7 @@ class GroupServerStore(SQLBaseStore):
},
)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="local_group_updates",
values={
@@ -920,7 +1056,7 @@ class GroupServerStore(SQLBaseStore):
if membership == "join":
if local_attestation:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -930,7 +1066,7 @@ class GroupServerStore(SQLBaseStore):
},
)
if remote_attestation:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
@@ -941,12 +1077,12 @@ class GroupServerStore(SQLBaseStore):
},
)
else:
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -955,18 +1091,18 @@ class GroupServerStore(SQLBaseStore):
return next_id
with self._group_updates_id_gen.get_next() as next_id:
- res = yield self.runInteraction(
+ res = yield self.db.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
next_id,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description
):
- yield self._simple_insert(
+ yield self.db.simple_insert(
table="groups",
values={
"group_id": group_id,
@@ -981,33 +1117,17 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def update_group_profile(self, group_id, profile):
- yield self._simple_update_one(
+ yield self.db.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues=profile,
desc="update_group_profile",
)
- def get_attestations_need_renewals(self, valid_until_ms):
- """Get all attestations that need to be renewed until givent time
- """
-
- def _get_attestations_need_renewals_txn(txn):
- sql = """
- SELECT group_id, user_id FROM group_attestations_renewals
- WHERE valid_until_ms <= ?
- """
- txn.execute(sql, (valid_until_ms,))
- return self.cursor_to_dict(txn)
-
- return self.runInteraction(
- "get_attestations_need_renewals", _get_attestations_need_renewals_txn
- )
-
def update_attestation_renewal(self, group_id, user_id, attestation):
"""Update an attestation that we have renewed
"""
- return self._simple_update_one(
+ return self.db.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
@@ -1017,7 +1137,7 @@ class GroupServerStore(SQLBaseStore):
def update_remote_attestion(self, group_id, user_id, attestation):
"""Update an attestation that a remote has renewed
"""
- return self._simple_update_one(
+ return self.db.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
@@ -1036,118 +1156,12 @@ class GroupServerStore(SQLBaseStore):
group_id (str)
user_id (str)
"""
- return self._simple_delete(
+ return self.db.simple_delete(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
desc="remove_attestation_renewal",
)
- @defer.inlineCallbacks
- def get_remote_attestation(self, group_id, user_id):
- """Get the attestation that proves the remote agrees that the user is
- in the group.
- """
- row = yield self._simple_select_one(
- table="group_attestations_remote",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcols=("valid_until_ms", "attestation_json"),
- desc="get_remote_attestation",
- allow_none=True,
- )
-
- now = int(self._clock.time_msec())
- if row and now < row["valid_until_ms"]:
- defer.returnValue(json.loads(row["attestation_json"]))
-
- defer.returnValue(None)
-
- def get_joined_groups(self, user_id):
- return self._simple_select_onecol(
- table="local_group_membership",
- keyvalues={"user_id": user_id, "membership": "join"},
- retcol="group_id",
- desc="get_joined_groups",
- )
-
- def get_all_groups_for_user(self, user_id, now_token):
- def _get_all_groups_for_user_txn(txn):
- sql = """
- SELECT group_id, type, membership, u.content
- FROM local_group_updates AS u
- INNER JOIN local_group_membership USING (group_id, user_id)
- WHERE user_id = ? AND membership != 'leave'
- AND stream_id <= ?
- """
- txn.execute(sql, (user_id, now_token))
- return [
- {
- "group_id": row[0],
- "type": row[1],
- "membership": row[2],
- "content": json.loads(row[3]),
- }
- for row in txn
- ]
-
- return self.runInteraction(
- "get_all_groups_for_user", _get_all_groups_for_user_txn
- )
-
- def get_groups_changes_for_user(self, user_id, from_token, to_token):
- from_token = int(from_token)
- has_changed = self._group_updates_stream_cache.has_entity_changed(
- user_id, from_token
- )
- if not has_changed:
- return []
-
- def _get_groups_changes_for_user_txn(txn):
- sql = """
- SELECT group_id, membership, type, u.content
- FROM local_group_updates AS u
- INNER JOIN local_group_membership USING (group_id, user_id)
- WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
- """
- txn.execute(sql, (user_id, from_token, to_token))
- return [
- {
- "group_id": group_id,
- "membership": membership,
- "type": gtype,
- "content": json.loads(content_json),
- }
- for group_id, membership, gtype, content_json in txn
- ]
-
- return self.runInteraction(
- "get_groups_changes_for_user", _get_groups_changes_for_user_txn
- )
-
- def get_all_groups_changes(self, from_token, to_token, limit):
- from_token = int(from_token)
- has_changed = self._group_updates_stream_cache.has_any_entity_changed(
- from_token
- )
- if not has_changed:
- return []
-
- def _get_all_groups_changes_txn(txn):
- sql = """
- SELECT stream_id, group_id, user_id, type, content
- FROM local_group_updates
- WHERE ? < stream_id AND stream_id <= ?
- LIMIT ?
- """
- txn.execute(sql, (from_token, to_token, limit))
- return [
- (stream_id, group_id, user_id, gtype, json.loads(content_json))
- for stream_id, group_id, user_id, gtype, content_json in txn
- ]
-
- return self.runInteraction(
- "get_all_groups_changes", _get_all_groups_changes_txn
- )
-
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
@@ -1178,12 +1192,8 @@ class GroupServerStore(SQLBaseStore):
]
for table in tables:
- self._simple_delete_txn(
- txn,
- table=table,
- keyvalues={"group_id": group_id},
+ self.db.simple_delete_txn(
+ txn, table=table, keyvalues={"group_id": group_id}
)
- return self.runInteraction(
- "delete_group", _delete_group_txn
- )
+ return self.db.runInteraction("delete_group", _delete_group_txn)
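The split of GroupServerStore into GroupServerWorkerStore plus a writer subclass follows Synapse's usual worker-store layout: read-only accessors move to the base class so worker processes can use them without the write paths, while mutations stay on the subclass used by the writing process. A toy sketch of the shape (the dict-backed "db" is a stand-in, not Synapse's Database class):

    class GroupWorkerStore:
        """Read-only accessors: safe to run on any worker process."""

        def __init__(self, db):
            self.db = db  # here just a dict of group_id -> group dict

        def get_group(self, group_id):
            return self.db.get(group_id)


    class GroupServerStore(GroupWorkerStore):
        """Mutations: only used on the process that owns writes."""

        def set_group_join_policy(self, group_id, join_policy):
            self.db[group_id]["join_policy"] = join_policy


    store = GroupServerStore({"+g:example.com": {"join_policy": "invite"}})
    store.set_group_join_policy("+g:example.com", "open")
    assert store.get_group("+g:example.com")["join_policy"] == "open"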
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
new file mode 100644
index 0000000000..ba89c68c9f
--- /dev/null
+++ b/synapse/storage/data_stores/main/keys.py
@@ -0,0 +1,214 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+
+import six
+
+from signedjson.key import decode_verify_key_bytes
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.keys import FetchKeyResult
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
+
+logger = logging.getLogger(__name__)
+
+# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
+# despite being deprecated and removed in favor of memoryview
+if six.PY2:
+ db_binary_type = six.moves.builtins.buffer
+else:
+ db_binary_type = memoryview
+
+
+class KeyStore(SQLBaseStore):
+ """Persistence for signature verification keys
+ """
+
+ @cached()
+ def _get_server_verify_key(self, server_name_and_key_id):
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
+ )
+ def get_server_verify_keys(self, server_name_and_key_ids):
+ """
+ Args:
+ server_name_and_key_ids (iterable[Tuple[str, str]]):
+ iterable of (server_name, key-id) tuples to fetch keys for
+
+ Returns:
+ Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
+ map from (server_name, key_id) -> FetchKeyResult, or None if the key is
+ unknown
+ """
+ keys = {}
+
+ def _get_keys(txn, batch):
+ """Processes a batch of keys to fetch, and adds the result to `keys`."""
+
+ # batch_iter always returns tuples so it's safe to do len(batch)
+ sql = (
+ "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
+ "FROM server_signature_keys WHERE 1=0"
+ ) + " OR (server_name=? AND key_id=?)" * len(batch)
+
+ txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
+
+ for row in txn:
+ server_name, key_id, key_bytes, ts_valid_until_ms = row
+
+ if ts_valid_until_ms is None:
+ # Old keys may be stored with a ts_valid_until_ms of null,
+ # in which case we treat this as if it was set to `0`, i.e.
+ # it won't match key requests that define a minimum
+ # `ts_valid_until_ms`.
+ ts_valid_until_ms = 0
+
+ res = FetchKeyResult(
+ verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
+ valid_until_ts=ts_valid_until_ms,
+ )
+ keys[(server_name, key_id)] = res
+
+ def _txn(txn):
+ for batch in batch_iter(server_name_and_key_ids, 50):
+ _get_keys(txn, batch)
+ return keys
+
+ return self.db.runInteraction("get_server_verify_keys", _txn)
+
+ def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+ """Stores NACL verification keys for remote servers.
+ Args:
+ from_server (str): Where the verification keys were looked up
+ ts_added_ms (int): The time to record that the key was added
+ verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+ keys to be stored. Each entry is a triplet of
+ (server_name, key_id, key).
+ """
+ key_values = []
+ value_values = []
+ invalidations = []
+ for server_name, key_id, fetch_result in verify_keys:
+ key_values.append((server_name, key_id))
+ value_values.append(
+ (
+ from_server,
+ ts_added_ms,
+ fetch_result.valid_until_ts,
+ db_binary_type(fetch_result.verify_key.encode()),
+ )
+ )
+ # invalidate takes a tuple corresponding to the params of
+ # _get_server_verify_key. _get_server_verify_key only takes one
+ # param, which is itself the 2-tuple (server_name, key_id).
+ invalidations.append((server_name, key_id))
+
+ def _invalidate(res):
+ f = self._get_server_verify_key.invalidate
+ for i in invalidations:
+ f((i,))
+ return res
+
+ return self.db.runInteraction(
+ "store_server_verify_keys",
+ self.db.simple_upsert_many_txn,
+ table="server_signature_keys",
+ key_names=("server_name", "key_id"),
+ key_values=key_values,
+ value_names=(
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "verify_key",
+ ),
+ value_values=value_values,
+ ).addCallback(_invalidate)
+
+ def store_server_keys_json(
+ self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
+ ):
+ """Stores the JSON bytes for a set of keys from a server
+ The JSON should be signed by the originating server, the intermediate
+ server, and by this server. Updates the value for the
+ (server_name, key_id, from_server) triplet if one already existed.
+ Args:
+ server_name (str): The name of the server.
+ key_id (str): The identifier of the key this JSON is for.
+ from_server (str): The server this JSON was fetched from.
+ ts_now_ms (int): The time now in milliseconds.
+ ts_expires_ms (int): The time when this json stops being valid.
+ key_json_bytes (bytes): The encoded JSON.
+ """
+ return self.db.simple_upsert(
+ table="server_keys_json",
+ keyvalues={
+ "server_name": server_name,
+ "key_id": key_id,
+ "from_server": from_server,
+ },
+ values={
+ "server_name": server_name,
+ "key_id": key_id,
+ "from_server": from_server,
+ "ts_added_ms": ts_now_ms,
+ "ts_valid_until_ms": ts_expires_ms,
+ "key_json": db_binary_type(key_json_bytes),
+ },
+ desc="store_server_keys_json",
+ )
+
+ def get_server_keys_json(self, server_keys):
+ """Retrive the key json for a list of server_keys and key ids.
+ If no keys are found for a given server, key_id and source, then
+ the entry for that (server, key_id, source) triplet will be an empty list.
+ The JSON is returned as a byte array so that it can be efficiently
+ used in an HTTP response.
+ Args:
+ server_keys (list): List of (server_name, key_id, source) triplets.
+ Returns:
+ Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
+ Dict mapping (server_name, key_id, source) triplets to lists of dicts
+ """
+
+ def _get_server_keys_json_txn(txn):
+ results = {}
+ for server_name, key_id, from_server in server_keys:
+ keyvalues = {"server_name": server_name}
+ if key_id is not None:
+ keyvalues["key_id"] = key_id
+ if from_server is not None:
+ keyvalues["from_server"] = from_server
+ rows = self.db.simple_select_list_txn(
+ txn,
+ "server_keys_json",
+ keyvalues=keyvalues,
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ )
+ results[(server_name, key_id, from_server)] = rows
+ return results
+
+ return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
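The query assembled in _get_keys starts from a clause that matches nothing (WHERE 1=0) and appends one "OR (server_name=? AND key_id=?)" per requested key, flattening the (server_name, key_id) pairs into a single parameter tuple. For example, with two keys (illustrative values):

    import itertools

    batch = [("example.com", "ed25519:abc"), ("matrix.org", "ed25519:def")]
    sql = (
        "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
        "FROM server_signature_keys WHERE 1=0"
    ) + " OR (server_name=? AND key_id=?)" * len(batch)
    params = tuple(itertools.chain.from_iterable(batch))
    assert params == ("example.com", "ed25519:abc", "matrix.org", "ed25519:def")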
diff --git a/synapse/storage/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 3ecf47e7a7..80ca36dedf 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -12,29 +12,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
-class MediaRepositoryStore(BackgroundUpdateStore):
- """Persistence for attachments and avatars"""
-
- def __init__(self, db_conn, hs):
- super(MediaRepositoryStore, self).__init__(db_conn, hs)
+class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(MediaRepositoryBackgroundUpdateStore, self).__init__(
+ database, db_conn, hs
+ )
- self.register_background_index_update(
- update_name='local_media_repository_url_idx',
- index_name='local_media_repository_url_idx',
- table='local_media_repository',
- columns=['created_ts'],
- where_clause='url_cache IS NOT NULL',
+ self.db.updates.register_background_index_update(
+ update_name="local_media_repository_url_idx",
+ index_name="local_media_repository_url_idx",
+ table="local_media_repository",
+ columns=["created_ts"],
+ where_clause="url_cache IS NOT NULL",
)
+
+class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
+ """Persistence for attachments and avatars"""
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
+
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:
None if the media_id doesn't exist.
"""
- return self._simple_select_one(
+ return self.db.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -59,7 +67,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
user_id,
url_cache=None,
):
- return self._simple_insert(
+ return self.db.simple_insert(
"local_media_repository",
{
"media_id": media_id,
@@ -108,23 +116,23 @@ class MediaRepositoryStore(BackgroundUpdateStore):
return dict(
zip(
(
- 'response_code',
- 'etag',
- 'expires_ts',
- 'og',
- 'media_id',
- 'download_ts',
+ "response_code",
+ "etag",
+ "expires_ts",
+ "og",
+ "media_id",
+ "download_ts",
),
row,
)
)
- return self.runInteraction("get_url_cache", get_url_cache_txn)
+ return self.db.runInteraction("get_url_cache", get_url_cache_txn)
def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
- return self._simple_insert(
+ return self.db.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
@@ -139,7 +147,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
def get_local_media_thumbnails(self, media_id):
- return self._simple_select_list(
+ return self.db.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
@@ -161,7 +169,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self._simple_insert(
+ return self.db.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
@@ -175,7 +183,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
def get_cached_remote_media(self, origin, media_id):
- return self._simple_select_one(
+ return self.db.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -200,7 +208,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
upload_name,
filesystem_id,
):
- return self._simple_insert(
+ return self.db.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
@@ -245,10 +253,12 @@ class MediaRepositoryStore(BackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
- return self.runInteraction("update_cached_last_access_time", update_cache_txn)
+ return self.db.runInteraction(
+ "update_cached_last_access_time", update_cache_txn
+ )
def get_remote_media_thumbnails(self, origin, media_id):
- return self._simple_select_list(
+ return self.db.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
@@ -273,7 +283,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self._simple_insert(
+ return self.db.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
@@ -295,24 +305,24 @@ class MediaRepositoryStore(BackgroundUpdateStore):
" WHERE last_access_ts < ?"
)
- return self._execute(
- "get_remote_media_before", self.cursor_to_dict, sql, before_ts
+ return self.db.execute(
+ "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts
)
def delete_remote_media(self, media_origin, media_id):
def delete_remote_media_txn(txn):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
"remote_media_cache",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
"remote_media_cache_thumbnails",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+ return self.db.runInteraction("delete_remote_media", delete_remote_media_txn)
def get_expired_url_cache(self, now_ts):
sql = (
@@ -326,18 +336,20 @@ class MediaRepositoryStore(BackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
- return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+ return self.db.runInteraction(
+ "get_expired_url_cache", _get_expired_url_cache_txn
+ )
def delete_url_cache(self, media_ids):
if len(media_ids) == 0:
return
- sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+ return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
def get_url_cache_media_before(self, before_ts):
sql = (
@@ -351,7 +363,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
@@ -360,14 +372,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
return
def _delete_url_cache_media_txn(txn):
- sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)
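
Most of the media_repository hunks above are mechanical: every `self._simple_*` helper and every bare `self.runInteraction` call now goes through the `Database` object exposed as `self.db`. A minimal sketch of the resulting calling convention (the method, table and column names below are illustrative, in the same style as the patch, not lines taken from it):

```python
from synapse.storage._base import SQLBaseStore


class ExampleMediaStore(SQLBaseStore):
    """Hypothetical store showing the post-refactor calling convention."""

    def get_local_thumbnails(self, media_id):
        # Single-table reads delegate straight to the Database helpers.
        return self.db.simple_select_list(
            table="local_media_repository_thumbnails",
            keyvalues={"media_id": media_id},
            retcols=("thumbnail_width", "thumbnail_height", "thumbnail_type"),
            desc="get_local_thumbnails",
        )

    def delete_local_media(self, media_id):
        # Multi-statement work is wrapped in a transaction function and handed
        # to self.db.runInteraction, exactly as in the hunks above.
        def _delete_local_media_txn(txn):
            self.db.simple_delete_txn(
                txn, "local_media_repository", keyvalues={"media_id": media_id}
            )
            self.db.simple_delete_txn(
                txn,
                "local_media_repository_thumbnails",
                keyvalues={"media_id": media_id},
            )

        return self.db.runInteraction("delete_local_media", _delete_local_media_txn)
```
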
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 8aa8abc470..925bc5691b 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -16,10 +16,10 @@ import logging
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.util.caches.descriptors import cached
-from ._base import SQLBaseStore
-
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the monthly_active_user timestamp
@@ -27,15 +27,105 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
-class MonthlyActiveUsersStore(SQLBaseStore):
- def __init__(self, dbconn, hs):
- super(MonthlyActiveUsersStore, self).__init__(None, hs)
+class MonthlyActiveUsersWorkerStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
- self.reserved_users = ()
+
+ @cached(num_args=0)
+ def get_monthly_active_count(self):
+ """Generates current count of monthly active users
+
+ Returns:
+ Deferred[int]: Number of current monthly active users
+ """
+
+ def _count_users(txn):
+ sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
+ txn.execute(sql)
+ (count,) = txn.fetchone()
+ return count
+
+ return self.db.runInteraction("count_users", _count_users)
+
+ @cached(num_args=0)
+ def get_monthly_active_count_by_service(self):
+ """Generates current count of monthly active users broken down by service.
+ A service is typically an appservice but also includes native matrix users.
+ Since the `monthly_active_users` table is populated from the `user_ips` table,
+ `config.track_appservice_user_ips` must be set to `true` for this
+ method to return anything other than native matrix users.
+
+ Returns:
+ Deferred[dict]: a mapping from app_service_id to the number of
+ occurrences.
+
+ """
+
+ def _count_users_by_service(txn):
+ sql = """
+ SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
+ FROM monthly_active_users
+ LEFT JOIN users ON monthly_active_users.user_id=users.name
+ GROUP BY appservice_id;
+ """
+
+ txn.execute(sql)
+ result = txn.fetchall()
+ return dict(result)
+
+ return self.db.runInteraction("count_users_by_service", _count_users_by_service)
+
+ @defer.inlineCallbacks
+ def get_registered_reserved_users(self):
+ """Of the reserved threepids defined in config, which are associated
+ with registered users?
+
+ Returns:
+ Deferred[list]: Real reserved users
+ """
+ users = []
+
+ for tp in self.hs.config.mau_limits_reserved_threepids[
+ : self.hs.config.max_mau_value
+ ]:
+ user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ tp["medium"], tp["address"]
+ )
+ if user_id:
+ users.append(user_id)
+
+ return users
+
+ @cached(num_args=1)
+ def user_last_seen_monthly_active(self, user_id):
+ """
+ Checks if a given user is part of the monthly active user group
+ Arguments:
+ user_id (str): user to check
+ Return:
+ Deferred[int]: timestamp of last seen, None if never seen
+
+ """
+
+ return self.db.simple_select_one_onecol(
+ table="monthly_active_users",
+ keyvalues={"user_id": user_id},
+ retcol="timestamp",
+ allow_none=True,
+ desc="user_last_seen_monthly_active",
+ )
+
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+
# Do not add more reserved users than the total allowable number
- self._new_transaction(
- dbconn,
+ self.db.new_transaction(
+ db_conn,
"initialise_mau_threepids",
[],
[],
@@ -51,7 +141,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
txn (cursor):
threepids (list[dict]): List of threepid dicts to reserve
"""
- reserved_user_list = []
for tp in threepids:
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
@@ -59,11 +148,15 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if user_id:
is_support = self.is_support_user_txn(txn, user_id)
if not is_support:
- self.upsert_monthly_active_user_txn(txn, user_id)
- reserved_user_list.append(user_id)
+ # We do this manually here to avoid hitting #6791
+ self.db.simple_upsert_txn(
+ txn,
+ table="monthly_active_users",
+ keyvalues={"user_id": user_id},
+ values={"timestamp": int(self._clock.time_msec())},
+ )
else:
logger.warning("mau limit reserved threepid %s not found in db" % tp)
- self.reserved_users = tuple(reserved_user_list)
@defer.inlineCallbacks
def reap_monthly_active_users(self):
@@ -74,8 +167,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Deferred[]
"""
- def _reap_users(txn):
- # Purge stale users
+ def _reap_users(txn, reserved_users):
+ """
+ Args:
+ reserved_users (tuple): reserved users to preserve
+ """
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
query_args = [thirty_days_ago]
@@ -83,20 +179,19 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
- if len(self.reserved_users) > 0:
+ if len(reserved_users) > 0:
# questionmarks is a hack to overcome sqlite not supporting
# tuples in 'WHERE IN %s'
- questionmarks = '?' * len(self.reserved_users)
+ question_marks = ",".join("?" * len(reserved_users))
- query_args.extend(self.reserved_users)
- sql = base_sql + """ AND user_id NOT IN ({})""".format(
- ','.join(questionmarks)
- )
+ query_args.extend(reserved_users)
+ sql = base_sql + " AND user_id NOT IN ({})".format(question_marks)
else:
sql = base_sql
txn.execute(sql, query_args)
+ max_mau_value = self.hs.config.max_mau_value
if self.hs.config.limit_usage_by_mau:
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
@@ -106,74 +201,64 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# Postgres does not require 'LIMIT', but it also does not support
# negative LIMIT values, so there is no way to write the query that
# both engines support.
- safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
- # Must be greater than zero for postgres
- safe_guard = safe_guard if safe_guard > 0 else 0
- query_args = [safe_guard]
-
- base_sql = """
- DELETE FROM monthly_active_users
- WHERE user_id NOT IN (
- SELECT user_id FROM monthly_active_users
- ORDER BY timestamp DESC
- LIMIT ?
+ if len(reserved_users) == 0:
+ sql = """
+ DELETE FROM monthly_active_users
+ WHERE user_id NOT IN (
+ SELECT user_id FROM monthly_active_users
+ ORDER BY timestamp DESC
+ LIMIT ?
)
- """
+ """
+ txn.execute(sql, (max_mau_value,))
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
- if len(self.reserved_users) > 0:
- query_args.extend(self.reserved_users)
- sql = base_sql + """ AND user_id NOT IN ({})""".format(
- ','.join(questionmarks)
- )
else:
- sql = base_sql
- txn.execute(sql, query_args)
-
- yield self.runInteraction("reap_monthly_active_users", _reap_users)
- # It seems poor to invalidate the whole cache, Postgres supports
- # 'Returning' which would allow me to invalidate only the
- # specific users, but sqlite has no way to do this and instead
- # I would need to SELECT and the DELETE which without locking
- # is racy.
- # Have resolved to invalidate the whole cache for now and do
- # something about it if and when the perf becomes significant
- self.user_last_seen_monthly_active.invalidate_all()
- self.get_monthly_active_count.invalidate_all()
-
- @cached(num_args=0)
- def get_monthly_active_count(self):
- """Generates current count of monthly active users
-
- Returns:
- Defered[int]: Number of current monthly active users
- """
-
- def _count_users(txn):
- sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
-
- txn.execute(sql)
- count, = txn.fetchone()
- return count
-
- return self.runInteraction("count_users", _count_users)
+ # Must be >= 0 for postgres
+ num_of_non_reserved_users_to_remove = max(
+ max_mau_value - len(reserved_users), 0
+ )
- @defer.inlineCallbacks
- def get_registered_reserved_users_count(self):
- """Of the reserved threepids defined in config, how many are associated
- with registered users?
+ # It is important to filter reserved users twice to guard
+ # against the case where the reserved user is present in the
+ # SELECT, meaning that a legitimate MAU is deleted.
+ sql = """
+ DELETE FROM monthly_active_users
+ WHERE user_id NOT IN (
+ SELECT user_id FROM monthly_active_users
+ WHERE user_id NOT IN ({})
+ ORDER BY timestamp DESC
+ LIMIT ?
+ )
+ AND user_id NOT IN ({})
+ """.format(
+ question_marks, question_marks
+ )
- Returns:
- Defered[int]: Number of real reserved users
- """
- count = 0
- for tp in self.hs.config.mau_limits_reserved_threepids:
- user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- tp["medium"], tp["address"]
+ query_args = [
+ *reserved_users,
+ num_of_non_reserved_users_to_remove,
+ *reserved_users,
+ ]
+
+ txn.execute(sql, query_args)
+
+ # It seems poor to invalidate the whole cache; Postgres supports
+ # 'RETURNING', which would allow us to invalidate only the
+ # specific users, but sqlite has no way to do this and instead
+ # we would need to SELECT and then DELETE, which without locking
+ # is racy.
+ # We have resolved to invalidate the whole cache for now and do
+ # something about it if and when the perf becomes significant.
+ self._invalidate_all_cache_and_stream(
+ txn, self.user_last_seen_monthly_active
)
- if user_id:
- count = count + 1
- defer.returnValue(count)
+ self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+
+ reserved_users = yield self.get_registered_reserved_users()
+ yield self.db.runInteraction(
+ "reap_monthly_active_users", _reap_users, reserved_users
+ )
@defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
@@ -195,27 +280,13 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if is_support:
return
- yield self.runInteraction(
+ yield self.db.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
- user_in_mau = self.user_last_seen_monthly_active.cache.get(
- (user_id,), None, update_metrics=False
- )
- if user_in_mau is None:
- self.get_monthly_active_count.invalidate(())
-
- self.user_last_seen_monthly_active.invalidate((user_id,))
-
def upsert_monthly_active_user_txn(self, txn, user_id):
"""Updates or inserts monthly active user member
- Note that, after calling this method, it will generally be necessary
- to invalidate the caches on user_last_seen_monthly_active and
- get_monthly_active_count. We can't do that here, because we are running
- in a database thread rather than the main thread, and we can't call
- txn.call_after because txn may not be a LoggingTransaction.
-
We consciously do not call is_support_txn from this method because it
is not possible to cache the response. is_support_txn will be false in
almost all cases, so it seems reasonable to call it only for
@@ -239,33 +310,22 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more
- is_insert = self._simple_upsert_txn(
+ is_insert = self.db.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
values={"timestamp": int(self._clock.time_msec())},
)
- return is_insert
-
- @cached(num_args=1)
- def user_last_seen_monthly_active(self, user_id):
- """
- Checks if a given user is part of the monthly active user group
- Arguments:
- user_id (str): user to add/update
- Return:
- Deferred[int] : timestamp since last seen, None if never seen
-
- """
-
- return self._simple_select_one_onecol(
- table="monthly_active_users",
- keyvalues={"user_id": user_id},
- retcol="timestamp",
- allow_none=True,
- desc="user_last_seen_monthly_active",
+ self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+ self._invalidate_cache_and_stream(
+ txn, self.get_monthly_active_count_by_service, ()
)
+ self._invalidate_cache_and_stream(
+ txn, self.user_last_seen_monthly_active, (user_id,)
+ )
+
+ return is_insert
@defer.inlineCallbacks
def populate_monthly_active_users(self, user_id):
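
The heart of the MAU changes is the reaping query: the reserved users are now passed into `_reap_users` and excluded twice, once inside the `SELECT ... ORDER BY timestamp DESC LIMIT ?` subquery and once in the outer `DELETE`, via a comma-joined run of `?` placeholders. Note that `num_of_non_reserved_users_to_remove`, despite its name, is used as the LIMIT on the newest non-reserved rows that are kept. A standalone sketch of how the SQL and its argument list are assembled (plain Python, no database needed; the helper name is mine):

```python
def build_reap_query(reserved_users, max_mau_value):
    """Build the DELETE that trims monthly_active_users down to the MAU cap,
    excluding reserved users both in the inner SELECT and in the outer DELETE,
    mirroring the double NOT IN filter in the hunk above."""
    if not reserved_users:
        sql = """
            DELETE FROM monthly_active_users
            WHERE user_id NOT IN (
                SELECT user_id FROM monthly_active_users
                ORDER BY timestamp DESC
                LIMIT ?
            )
        """
        return sql, [max_mau_value]

    question_marks = ",".join("?" * len(reserved_users))
    # Reserved users are always kept, so only this many of the newest
    # non-reserved users survive the reap (never negative, for Postgres).
    num_non_reserved_to_keep = max(max_mau_value - len(reserved_users), 0)
    sql = """
        DELETE FROM monthly_active_users
        WHERE user_id NOT IN (
            SELECT user_id FROM monthly_active_users
            WHERE user_id NOT IN ({})
            ORDER BY timestamp DESC
            LIMIT ?
        )
        AND user_id NOT IN ({})
    """.format(
        question_marks, question_marks
    )
    return sql, [*reserved_users, num_non_reserved_to_keep, *reserved_users]


sql, args = build_reap_query(["@res1:hs", "@res2:hs"], max_mau_value=5)
print(args)  # ['@res1:hs', '@res2:hs', 3, '@res1:hs', '@res2:hs']
```
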
diff --git a/synapse/storage/openid.py b/synapse/storage/data_stores/main/openid.py
index b3318045ee..cc21437e92 100644
--- a/synapse/storage/openid.py
+++ b/synapse/storage/data_stores/main/openid.py
@@ -1,9 +1,9 @@
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
- return self._simple_insert(
+ return self.db.simple_insert(
table="open_id_tokens",
values={
"token": token,
@@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
- return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)
+ return self.db.runInteraction(
+ "get_user_id_for_token", get_user_id_for_token_txn
+ )
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
new file mode 100644
index 0000000000..604c8b7ddd
--- /dev/null
+++ b/synapse/storage/data_stores/main/presence.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.presence import UserPresenceState
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
+
+
+class PresenceStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def update_presence(self, presence_states):
+ stream_ordering_manager = self._presence_id_gen.get_next_mult(
+ len(presence_states)
+ )
+
+ with stream_ordering_manager as stream_orderings:
+ yield self.db.runInteraction(
+ "update_presence",
+ self._update_presence_txn,
+ stream_orderings,
+ presence_states,
+ )
+
+ return stream_orderings[-1], self._presence_id_gen.get_current_token()
+
+ def _update_presence_txn(self, txn, stream_orderings, presence_states):
+ for stream_id, state in zip(stream_orderings, presence_states):
+ txn.call_after(
+ self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
+ )
+ txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
+
+ # Actually insert new rows
+ self.db.simple_insert_many_txn(
+ txn,
+ table="presence_stream",
+ values=[
+ {
+ "stream_id": stream_id,
+ "user_id": state.user_id,
+ "state": state.state,
+ "last_active_ts": state.last_active_ts,
+ "last_federation_update_ts": state.last_federation_update_ts,
+ "last_user_sync_ts": state.last_user_sync_ts,
+ "status_msg": state.status_msg,
+ "currently_active": state.currently_active,
+ }
+ for state in presence_states
+ ],
+ )
+
+ # Delete old rows to stop database from getting really big
+ sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
+
+ for states in batch_iter(presence_states, 50):
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "user_id", [s.user_id for s in states]
+ )
+ txn.execute(sql + clause, [stream_id] + list(args))
+
+ def get_all_presence_updates(self, last_id, current_id):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_presence_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, user_id, state, last_active_ts,"
+ " last_federation_update_ts, last_user_sync_ts, status_msg,"
+ " currently_active"
+ " FROM presence_stream"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ )
+ txn.execute(sql, (last_id, current_id))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_presence_updates", get_all_presence_updates_txn
+ )
+
+ @cached()
+ def _get_presence_for_user(self, user_id):
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_presence_for_user",
+ list_name="user_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
+ def get_presence_for_users(self, user_ids):
+ rows = yield self.db.simple_select_many_batch(
+ table="presence_stream",
+ column="user_id",
+ iterable=user_ids,
+ keyvalues={},
+ retcols=(
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
+ desc="get_presence_for_users",
+ )
+
+ for row in rows:
+ row["currently_active"] = bool(row["currently_active"])
+
+ return {row["user_id"]: UserPresenceState(**row) for row in rows}
+
+ def get_current_presence_token(self):
+ return self._presence_id_gen.get_current_token()
+
+ def allow_presence_visible(self, observed_localpart, observer_userid):
+ return self.db.simple_insert(
+ table="presence_allow_inbound",
+ values={
+ "observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid,
+ },
+ desc="allow_presence_visible",
+ or_ignore=True,
+ )
+
+ def disallow_presence_visible(self, observed_localpart, observer_userid):
+ return self.db.simple_delete_one(
+ table="presence_allow_inbound",
+ keyvalues={
+ "observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid,
+ },
+ desc="disallow_presence_visible",
+ )
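
One detail worth calling out in the new presence store is the pruning loop at the end of `_update_presence_txn`: superseded rows are deleted in chunks of 50 users, with the IN-list clause rebuilt per chunk. A self-contained sketch of that shape against sqlite, using simplified stand-ins for `batch_iter` and `make_in_list_sql_clause` (the real helpers live in `synapse.util.iterutils` and `synapse.storage._base`; the stand-ins here are assumptions made to keep the example runnable):

```python
import sqlite3
from typing import Iterable, Iterator, List, Tuple


def batch_iter(iterable: Iterable, size: int) -> Iterator[tuple]:
    """Simplified stand-in for synapse.util.iterutils.batch_iter: yield
    fixed-size chunks of an iterable."""
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == size:
            yield tuple(batch)
            batch = []
    if batch:
        yield tuple(batch)


def make_in_list_clause(column: str, values: Iterable[str]) -> Tuple[str, List[str]]:
    """Simplified stand-in for make_in_list_sql_clause: build 'col IN (?,?,...)'
    plus its argument list (the real helper also knows about Postgres arrays)."""
    values = list(values)
    return "%s IN (%s)" % (column, ",".join("?" * len(values))), values


def prune_presence_stream(conn, max_stream_id: int, user_ids: Iterable[str]) -> None:
    """Delete old presence rows in chunks of 50 users, mirroring the shape of
    the loop at the end of _update_presence_txn above."""
    sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
    with conn:
        for chunk in batch_iter(user_ids, 50):
            clause, args = make_in_list_clause("user_id", chunk)
            conn.execute(sql + clause, [max_stream_id] + args)


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE presence_stream (stream_id INTEGER, user_id TEXT)")
    conn.executemany(
        "INSERT INTO presence_stream VALUES (?, ?)",
        [(1, "@a:hs"), (2, "@a:hs"), (1, "@b:hs")],
    )
    prune_presence_stream(conn, max_stream_id=2, user_ids=["@a:hs", "@b:hs"])
    print(conn.execute("SELECT * FROM presence_stream").fetchall())
    # only (2, '@a:hs') remains
```
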
diff --git a/synapse/storage/profile.py b/synapse/storage/data_stores/main/profile.py
index 38524f2545..2a97991d23 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -17,10 +17,9 @@
from twisted.internet import defer
from synapse.api.errors import StoreError
-from synapse.storage.roommember import ProfileInfo
-
-from . import background_updates
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.roommember import ProfileInfo
+from synapse.util.caches.descriptors import cached
BATCH_SIZE = 100
@@ -29,7 +28,7 @@ class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_profileinfo(self, user_localpart):
try:
- profile = yield self._simple_select_one(
+ profile = yield self.db.simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
@@ -38,27 +37,26 @@ class ProfileWorkerStore(SQLBaseStore):
except StoreError as e:
if e.code == 404:
# no match
- defer.returnValue(ProfileInfo(None, None))
- return
+ return ProfileInfo(None, None)
else:
raise
- defer.returnValue(
- ProfileInfo(
- avatar_url=profile['avatar_url'], display_name=profile['displayname']
- )
+ return ProfileInfo(
+ avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
+ @cached(max_entries=5000)
def get_profile_displayname(self, user_localpart):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
+ @cached(max_entries=5000)
def get_profile_avatar_url(self, user_localpart):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
@@ -68,18 +66,15 @@ class ProfileWorkerStore(SQLBaseStore):
def get_latest_profile_replication_batch_number(self):
def f(txn):
txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
- rows = self.cursor_to_dict(txn)
- return rows[0]['maxbatch']
- return self.runInteraction(
- "get_latest_profile_replication_batch_number", f,
- )
+ rows = self.db.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return self.db.runInteraction("get_latest_profile_replication_batch_number", f)
def get_profile_batch(self, batchnum):
- return self._simple_select_list(
+ return self.db.simple_select_list(
table="profiles",
- keyvalues={
- "batch": batchnum,
- },
+ keyvalues={"batch": batchnum},
retcols=("user_id", "displayname", "avatar_url", "active"),
desc="get_profile_batch",
)
@@ -95,27 +90,29 @@ class ProfileWorkerStore(SQLBaseStore):
)
txn.execute(sql, (BATCH_SIZE,))
return txn.rowcount
- return self.runInteraction("assign_profile_batch", f)
+
+ return self.db.runInteraction("assign_profile_batch", f)
def get_replication_hosts(self):
def f(txn):
- txn.execute("SELECT host, last_synced_batch FROM profile_replication_status")
- rows = self.cursor_to_dict(txn)
- return {r['host']: r['last_synced_batch'] for r in rows}
- return self.runInteraction("get_replication_hosts", f)
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.db.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return self.db.runInteraction("get_replication_hosts", f)
def update_replication_batch_for_host(self, host, last_synced_batch):
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="profile_replication_status",
keyvalues={"host": host},
- values={
- "last_synced_batch": last_synced_batch,
- },
+ values={"last_synced_batch": last_synced_batch},
desc="update_replication_batch_for_host",
)
def get_from_remote_profile_cache(self, user_id):
- return self._simple_select_one(
+ return self.db.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
@@ -123,55 +120,57 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
+ def create_profile(self, user_localpart):
+ return self.db.simple_insert(
+ table="profiles", values={"user_id": user_localpart}, desc="create_profile"
+ )
+
def set_profile_displayname(self, user_localpart, new_displayname, batchnum):
- return self._simple_upsert(
+ # Invalidate the read cache for this user
+ self.get_profile_displayname.invalidate((user_localpart,))
+
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- values={
- "displayname": new_displayname,
- "batch": batchnum,
- },
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
- lock=False # we can do this because user_id has a unique index
+ lock=False, # we can do this because user_id has a unique index
)
def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum):
- return self._simple_upsert(
+ # Invalidate the read cache for this user
+ self.get_profile_avatar_url.invalidate((user_localpart,))
+
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- values={
- "avatar_url": new_avatar_url,
- "batch": batchnum,
- },
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
- lock=False # we can do this because user_id has a unique index
+ lock=False, # we can do this because user_id has a unique index
)
def set_profile_active(self, user_localpart, active, hide, batchnum):
- values = {
- "active": int(active),
- "batch": batchnum,
- }
+ values = {"active": int(active), "batch": batchnum}
if not active and not hide:
# we are deactivating for real (not in hide mode)
# so clear the profile.
values["avatar_url"] = None
values["displayname"] = None
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
values=values,
desc="set_profile_active",
- lock=False # we can do this because user_id has a unique index
+ lock=False, # we can do this because user_id has a unique index
)
-class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore):
- def __init__(self, db_conn, hs):
+class ProfileStore(ProfileWorkerStore):
+ def __init__(self, database, db_conn, hs):
- super(ProfileStore, self).__init__(db_conn, hs)
+ super(ProfileStore, self).__init__(database, db_conn, hs)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"profile_replication_status_host_index",
index_name="profile_replication_status_idx",
table="profile_replication_status",
@@ -185,7 +184,7 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@@ -197,7 +196,7 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self._simple_update(
+ return self.db.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@@ -215,7 +214,7 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
"""
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed:
- yield self._simple_delete(
+ yield self.db.simple_delete(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache",
@@ -234,9 +233,9 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
txn.execute(sql, (last_checked,))
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- return self.runInteraction(
+ return self.db.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
@@ -245,7 +244,7 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile.
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
@@ -254,9 +253,9 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
)
if res:
- defer.returnValue(True)
+ return True
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="group_invites",
keyvalues={"user_id": user_id},
retcol="user_id",
@@ -265,4 +264,4 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
)
if res:
- defer.returnValue(True)
+ return True
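
The profile hunks pair the new `@cached` read methods (`get_profile_displayname`, `get_profile_avatar_url`) with explicit `.invalidate((user_localpart,))` calls on the corresponding write paths. A toy model of that read-cache-plus-invalidate pattern, with a plain dict standing in for both the `profiles` table and the real `@cached` descriptor, so it runs without Synapse:

```python
class ToyCachedStore:
    """Toy model of get_profile_displayname / set_profile_displayname above."""

    def __init__(self):
        self._db = {}     # pretend "profiles" table: user_id -> displayname
        self._cache = {}  # read cache keyed by (user_localpart,)

    def get_profile_displayname(self, user_localpart):
        key = (user_localpart,)
        if key not in self._cache:
            self._cache[key] = self._db.get(user_localpart)  # "SELECT"
        return self._cache[key]

    def set_profile_displayname(self, user_localpart, new_displayname):
        # Invalidate the read cache for this user before the upsert, so the
        # next read sees the new value rather than a stale cached row.
        self._cache.pop((user_localpart,), None)
        self._db[user_localpart] = new_displayname  # "UPSERT"


store = ToyCachedStore()
store.set_profile_displayname("alice", "Alice")
print(store.get_profile_displayname("alice"))  # Alice
store.set_profile_displayname("alice", "Alice Cooper")
print(store.get_profile_displayname("alice"))  # Alice Cooper, not a stale hit
```
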
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
new file mode 100644
index 0000000000..62ac88d9f2
--- /dev/null
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -0,0 +1,714 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import logging
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.push.baserules import list_with_base_rules
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.data_stores.main.pusher import PusherWorkerStore
+from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
+from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.database import Database
+from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+logger = logging.getLogger(__name__)
+
+
+def _load_rules(rawrules, enabled_map):
+ ruleslist = []
+ for rawrule in rawrules:
+ rule = dict(rawrule)
+ rule["conditions"] = json.loads(rawrule["conditions"])
+ rule["actions"] = json.loads(rawrule["actions"])
+ ruleslist.append(rule)
+
+ # We're going to be mutating this a lot, so do a deep copy
+ rules = list(list_with_base_rules(ruleslist))
+
+ for i, rule in enumerate(rules):
+ rule_id = rule["rule_id"]
+ if rule_id in enabled_map:
+ if rule.get("enabled", True) != bool(enabled_map[rule_id]):
+ # Rules are cached across users.
+ rule = dict(rule)
+ rule["enabled"] = bool(enabled_map[rule_id])
+ rules[i] = rule
+
+ return rules
+
+
+class PushRulesWorkerStore(
+ ApplicationServiceWorkerStore,
+ ReceiptsWorkerStore,
+ PusherWorkerStore,
+ RoomMemberWorkerStore,
+ SQLBaseStore,
+):
+ """This is an abstract base class where subclasses must implement
+ `get_max_push_rules_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
+
+ push_rules_prefill, push_rules_id = self.db.get_cache_dict(
+ db_conn,
+ "push_rules_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self.get_max_push_rules_stream_id(),
+ )
+
+ self.push_rules_stream_cache = StreamChangeCache(
+ "PushRulesStreamChangeCache",
+ push_rules_id,
+ prefilled_cache=push_rules_prefill,
+ )
+
+ @abc.abstractmethod
+ def get_max_push_rules_stream_id(self):
+ """Get the position of the push rules stream.
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
+ @cachedInlineCallbacks(max_entries=5000)
+ def get_push_rules_for_user(self, user_id):
+ rows = yield self.db.simple_select_list(
+ table="push_rules",
+ keyvalues={"user_name": user_id},
+ retcols=(
+ "user_name",
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
+ ),
+ desc="get_push_rules_enabled_for_user",
+ )
+
+ rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+
+ enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+
+ rules = _load_rules(rows, enabled_map)
+
+ return rules
+
+ @cachedInlineCallbacks(max_entries=5000)
+ def get_push_rules_enabled_for_user(self, user_id):
+ results = yield self.db.simple_select_list(
+ table="push_rules_enable",
+ keyvalues={"user_name": user_id},
+ retcols=("user_name", "rule_id", "enabled"),
+ desc="get_push_rules_enabled_for_user",
+ )
+ return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
+
+ def have_push_rules_changed_for_user(self, user_id, last_id):
+ if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+ return defer.succeed(False)
+ else:
+
+ def have_push_rules_changed_txn(txn):
+ sql = (
+ "SELECT COUNT(stream_id) FROM push_rules_stream"
+ " WHERE user_id = ? AND ? < stream_id"
+ )
+ txn.execute(sql, (user_id, last_id))
+ (count,) = txn.fetchone()
+ return bool(count)
+
+ return self.db.runInteraction(
+ "have_push_rules_changed", have_push_rules_changed_txn
+ )
+
+ @cachedList(
+ cached_method_name="get_push_rules_for_user",
+ list_name="user_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
+ def bulk_get_push_rules(self, user_ids):
+ if not user_ids:
+ return {}
+
+ results = {user_id: [] for user_id in user_ids}
+
+ rows = yield self.db.simple_select_many_batch(
+ table="push_rules",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("*",),
+ desc="bulk_get_push_rules",
+ )
+
+ rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+
+ for row in rows:
+ results.setdefault(row["user_name"], []).append(row)
+
+ enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+
+ for user_id, rules in results.items():
+ results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
+
+ return results
+
+ @defer.inlineCallbacks
+ def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+ """Copy a single push rule from one room to another for a specific user.
+
+ Args:
+ new_room_id (str): ID of the new room.
+ user_id (str): ID of user the push rule belongs to.
+ rule (Dict): A push rule.
+ """
+ # Create new rule id
+ rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+ new_rule_id = rule_id_scope + "/" + new_room_id
+
+ # Change room id in each condition
+ for condition in rule.get("conditions", []):
+ if condition.get("key") == "room_id":
+ condition["pattern"] = new_room_id
+
+ # Add the rule for the new room
+ yield self.add_push_rule(
+ user_id=user_id,
+ rule_id=new_rule_id,
+ priority_class=rule["priority_class"],
+ conditions=rule["conditions"],
+ actions=rule["actions"],
+ )
+
+ @defer.inlineCallbacks
+ def copy_push_rules_from_room_to_room_for_user(
+ self, old_room_id, new_room_id, user_id
+ ):
+ """Copy all of the push rules from one room to another for a specific
+ user.
+
+ Args:
+ old_room_id (str): ID of the old room.
+ new_room_id (str): ID of the new room.
+ user_id (str): ID of user to copy push rules for.
+ """
+ # Retrieve push rules for this user
+ user_push_rules = yield self.get_push_rules_for_user(user_id)
+
+ # Get rules relating to the old room and copy them to the new room
+ for rule in user_push_rules:
+ conditions = rule.get("conditions", [])
+ if any(
+ (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
+ for c in conditions
+ ):
+ yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
+
+ @defer.inlineCallbacks
+ def bulk_get_push_rules_for_room(self, event, context):
+ state_group = context.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ current_state_ids = yield context.get_current_state_ids()
+ result = yield self._bulk_get_push_rules_for_room(
+ event.room_id, state_group, current_state_ids, event=event
+ )
+ return result
+
+ @cachedInlineCallbacks(num_args=2, cache_context=True)
+ def _bulk_get_push_rules_for_room(
+ self, room_id, state_group, current_state_ids, cache_context, event=None
+ ):
+ # We don't use `state_group`, it's there so that we can cache based
+ # on it. However, it's important that it's never None, since two current_states
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ # We will also want to generate notifs for other people in the room so
+ # their unread counts are correct in the event stream, but to avoid
+ # generating them for bot / AS users etc, we only do so for people who've
+ # sent a read receipt into the room.
+
+ users_in_room = yield self._get_joined_users_from_context(
+ room_id,
+ state_group,
+ current_state_ids,
+ on_invalidate=cache_context.invalidate,
+ event=event,
+ )
+
+ # We ignore app service users for now. This is so that we don't fill
+ # up the `get_if_users_have_pushers` cache with AS entries that we
+ # know don't have pushers, nor even read receipts.
+ local_users_in_room = {
+ u
+ for u in users_in_room
+ if self.hs.is_mine_id(u)
+ and not self.get_if_app_services_interested_in_user(u)
+ }
+
+ # users in the room who have pushers need to get push rules run because
+ # that's how their pushers work
+ if_users_with_pushers = yield self.get_if_users_have_pushers(
+ local_users_in_room, on_invalidate=cache_context.invalidate
+ )
+ user_ids = {
+ uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+ }
+
+ users_with_receipts = yield self.get_users_with_read_receipts_in_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+
+ # any users with pushers must be ours: they have pushers
+ for uid in users_with_receipts:
+ if uid in local_users_in_room:
+ user_ids.add(uid)
+
+ rules_by_user = yield self.bulk_get_push_rules(
+ user_ids, on_invalidate=cache_context.invalidate
+ )
+
+ rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
+
+ return rules_by_user
+
+ @cachedList(
+ cached_method_name="get_push_rules_enabled_for_user",
+ list_name="user_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
+ def bulk_get_push_rules_enabled(self, user_ids):
+ if not user_ids:
+ return {}
+
+ results = {user_id: {} for user_id in user_ids}
+
+ rows = yield self.db.simple_select_many_batch(
+ table="push_rules_enable",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("user_name", "rule_id", "enabled"),
+ desc="bulk_get_push_rules_enabled",
+ )
+ for row in rows:
+ enabled = bool(row["enabled"])
+ results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
+ return results
+
+
+class PushRuleStore(PushRulesWorkerStore):
+ @defer.inlineCallbacks
+ def add_push_rule(
+ self,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions,
+ actions,
+ before=None,
+ after=None,
+ ):
+ conditions_json = json.dumps(conditions)
+ actions_json = json.dumps(actions)
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ if before or after:
+ yield self.db.runInteraction(
+ "_add_push_rule_relative_txn",
+ self._add_push_rule_relative_txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
+ before,
+ after,
+ )
+ else:
+ yield self.db.runInteraction(
+ "_add_push_rule_highest_priority_txn",
+ self._add_push_rule_highest_priority_txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
+ )
+
+ def _add_push_rule_relative_txn(
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
+ before,
+ after,
+ ):
+ # Lock the table since otherwise we'll have annoying races between the
+ # SELECT here and the UPSERT below.
+ self.database_engine.lock_table(txn, "push_rules")
+
+ relative_to_rule = before or after
+
+ res = self.db.simple_select_one_txn(
+ txn,
+ table="push_rules",
+ keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
+ retcols=["priority_class", "priority"],
+ allow_none=True,
+ )
+
+ if not res:
+ raise RuleNotFoundException(
+ "before/after rule not found: %s" % (relative_to_rule,)
+ )
+
+ base_priority_class = res["priority_class"]
+ base_rule_priority = res["priority"]
+
+ if base_priority_class != priority_class:
+ raise InconsistentRuleException(
+ "Given priority class does not match class of relative rule"
+ )
+
+ if before:
+ # Higher priority rules are executed first, so adding a rule before
+ # another rule means giving it a higher priority than that rule.
+ new_rule_priority = base_rule_priority + 1
+ else:
+ # We increment the priority of the existing rules to make space for
+ # the new rule. Therefore if we want this rule to appear after
+ # an existing rule we give it the priority of the existing rule,
+ # and then increment the priority of the existing rule.
+ new_rule_priority = base_rule_priority
+
+ sql = (
+ "UPDATE push_rules SET priority = priority + 1"
+ " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
+ )
+
+ txn.execute(sql, (user_id, priority_class, new_rule_priority))
+
+ self._upsert_push_rule_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ new_rule_priority,
+ conditions_json,
+ actions_json,
+ )
+
+ def _add_push_rule_highest_priority_txn(
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
+ ):
+ # Lock the table since otherwise we'll have annoying races between the
+ # SELECT here and the UPSERT below.
+ self.database_engine.lock_table(txn, "push_rules")
+
+ # find the highest priority rule in that class
+ sql = (
+ "SELECT COUNT(*), MAX(priority) FROM push_rules"
+ " WHERE user_name = ? and priority_class = ?"
+ )
+ txn.execute(sql, (user_id, priority_class))
+ res = txn.fetchall()
+ (how_many, highest_prio) = res[0]
+
+ new_prio = 0
+ if how_many > 0:
+ new_prio = highest_prio + 1
+
+ self._upsert_push_rule_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ new_prio,
+ conditions_json,
+ actions_json,
+ )
+
+ def _upsert_push_rule_txn(
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ priority,
+ conditions_json,
+ actions_json,
+ update_stream=True,
+ ):
+ """Specialised version of simple_upsert_txn that picks a push_rule_id
+ using the _push_rule_id_gen if it needs to insert the rule. It assumes
+ that the "push_rules" table is locked"""
+
+ sql = (
+ "UPDATE push_rules"
+ " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
+ " WHERE user_name = ? AND rule_id = ?"
+ )
+
+ txn.execute(
+ sql,
+ (priority_class, priority, conditions_json, actions_json, user_id, rule_id),
+ )
+
+ if txn.rowcount == 0:
+ # We didn't update a row with the given rule_id so insert one
+ push_rule_id = self._push_rule_id_gen.get_next()
+
+ self.db.simple_insert_txn(
+ txn,
+ table="push_rules",
+ values={
+ "id": push_rule_id,
+ "user_name": user_id,
+ "rule_id": rule_id,
+ "priority_class": priority_class,
+ "priority": priority,
+ "conditions": conditions_json,
+ "actions": actions_json,
+ },
+ )
+
+ if update_stream:
+ self._insert_push_rules_update_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ op="ADD",
+ data={
+ "priority_class": priority_class,
+ "priority": priority,
+ "conditions": conditions_json,
+ "actions": actions_json,
+ },
+ )
+
+ @defer.inlineCallbacks
+ def delete_push_rule(self, user_id, rule_id):
+ """
+ Delete a push rule. Args specify the row to be deleted and can be
+ any of the columns in the push_rules table, but below are the
+ standard ones
+
+ Args:
+ user_id (str): The matrix ID of the push rule owner
+ rule_id (str): The rule_id of the rule to be deleted
+ """
+
+ def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+ self.db.simple_delete_one_txn(
+ txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
+ )
+
+ self._insert_push_rules_update_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
+ )
+
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.db.runInteraction(
+ "delete_push_rule",
+ delete_push_rule_txn,
+ stream_id,
+ event_stream_ordering,
+ )
+
+ @defer.inlineCallbacks
+ def set_push_rule_enabled(self, user_id, rule_id, enabled):
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.db.runInteraction(
+ "_set_push_rule_enabled_txn",
+ self._set_push_rule_enabled_txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ enabled,
+ )
+
+ def _set_push_rule_enabled_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+ ):
+ new_id = self._push_rules_enable_id_gen.get_next()
+ self.db.simple_upsert_txn(
+ txn,
+ "push_rules_enable",
+ {"user_name": user_id, "rule_id": rule_id},
+ {"enabled": 1 if enabled else 0},
+ {"id": new_id},
+ )
+
+ self._insert_push_rules_update_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ op="ENABLE" if enabled else "DISABLE",
+ )
+
+ @defer.inlineCallbacks
+ def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+ actions_json = json.dumps(actions)
+
+ def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
+ if is_default_rule:
+ # Add a dummy rule to the rules table with the user specified
+ # actions.
+ priority_class = -1
+ priority = 1
+ self._upsert_push_rule_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ priority,
+ "[]",
+ actions_json,
+ update_stream=False,
+ )
+ else:
+ self.db.simple_update_one_txn(
+ txn,
+ "push_rules",
+ {"user_name": user_id, "rule_id": rule_id},
+ {"actions": actions_json},
+ )
+
+ self._insert_push_rules_update_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ op="ACTIONS",
+ data={"actions": actions_json},
+ )
+
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.db.runInteraction(
+ "set_push_rule_actions",
+ set_push_rule_actions_txn,
+ stream_id,
+ event_stream_ordering,
+ )
+
+ def _insert_push_rules_update_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
+ ):
+ values = {
+ "stream_id": stream_id,
+ "event_stream_ordering": event_stream_ordering,
+ "user_id": user_id,
+ "rule_id": rule_id,
+ "op": op,
+ }
+ if data is not None:
+ values.update(data)
+
+ self.db.simple_insert_txn(txn, "push_rules_stream", values=values)
+
+ txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
+ txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
+ txn.call_after(
+ self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
+ )
+
+ def get_all_push_rule_updates(self, last_id, current_id, limit):
+ """Get all the push rules changes that have happend on the server"""
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_push_rule_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
+ " op, priority_class, priority, conditions, actions"
+ " FROM push_rules_stream"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_push_rule_updates", get_all_push_rule_updates_txn
+ )
+
+ def get_push_rules_stream_token(self):
+ """Get the position of the push rules stream.
+ Returns a pair of a stream id for the push_rules stream and the
+ room stream ordering it corresponds to."""
+ return self._push_rules_stream_id_gen.get_current_token()
+
+ def get_max_push_rules_stream_id(self):
+ return self.get_push_rules_stream_token()[0]
diff --git a/synapse/storage/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 1567e1df48..547b9d69cb 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -15,55 +15,45 @@
# limitations under the License.
import logging
-
-import six
+from typing import Iterable, Iterator
from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from ._base import SQLBaseStore
-
logger = logging.getLogger(__name__)
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
-
class PusherWorkerStore(SQLBaseStore):
- def _decode_pushers_rows(self, rows):
+ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+ """JSON-decode the data in the rows returned from the `pushers` table
+
+ Drops any rows whose data cannot be decoded
+ """
for r in rows:
- dataJson = r['data']
- r['data'] = None
+ dataJson = r["data"]
try:
- if isinstance(dataJson, db_binary_type):
- dataJson = str(dataJson).decode("UTF8")
-
- r['data'] = json.loads(dataJson)
+ r["data"] = json.loads(dataJson)
except Exception as e:
- logger.warn(
+ logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
- r['id'],
+ r["id"],
dataJson,
e.args[0],
)
- pass
-
- if isinstance(r['pushkey'], db_binary_type):
- r['pushkey'] = str(r['pushkey']).decode("UTF8")
+ continue
- return rows
+ yield r
@defer.inlineCallbacks
def user_has_pusher(self, user_id):
- ret = yield self._simple_select_one_onecol(
+ ret = yield self.db.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
- defer.returnValue(ret is not None)
+ return ret is not None
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
@@ -73,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
- ret = yield self._simple_select_list(
+ ret = yield self.db.simple_select_list(
"pushers",
keyvalues,
[
@@ -95,18 +85,18 @@ class PusherWorkerStore(SQLBaseStore):
],
desc="get_pushers_by",
)
- defer.returnValue(self._decode_pushers_rows(ret))
+ return self._decode_pushers_rows(ret)
@defer.inlineCallbacks
def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
- rows = yield self.runInteraction("get_all_pushers", get_pushers)
- defer.returnValue(rows)
+ rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
+ return rows
def get_all_updated_pushers(self, last_id, current_id, limit):
if last_id == current_id:
@@ -133,9 +123,9 @@ class PusherWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
deleted = txn.fetchall()
- return (updated, deleted)
+ return updated, deleted
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_pushers", get_all_updated_pushers_txn
)
@@ -178,7 +168,7 @@ class PusherWorkerStore(SQLBaseStore):
return results
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
@@ -194,18 +184,96 @@ class PusherWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_if_users_have_pushers(self, user_ids):
- rows = yield self._simple_select_many_batch(
- table='pushers',
- column='user_name',
+ rows = yield self.db.simple_select_many_batch(
+ table="pushers",
+ column="user_name",
iterable=user_ids,
- retcols=['user_name'],
- desc='get_if_users_have_pushers',
+ retcols=["user_name"],
+ desc="get_if_users_have_pushers",
)
result = {user_id: False for user_id in user_ids}
- result.update({r['user_name']: True for r in rows})
+ result.update({r["user_name"]: True for r in rows})
+
+ return result
+
+ @defer.inlineCallbacks
+ def update_pusher_last_stream_ordering(
+ self, app_id, pushkey, user_id, last_stream_ordering
+ ):
+ yield self.db.simple_update_one(
+ "pushers",
+ {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+ {"last_stream_ordering": last_stream_ordering},
+ desc="update_pusher_last_stream_ordering",
+ )
+
+ @defer.inlineCallbacks
+ def update_pusher_last_stream_ordering_and_success(
+ self, app_id, pushkey, user_id, last_stream_ordering, last_success
+ ):
+ """Update the last stream ordering position we've processed up to for
+ the given pusher.
- defer.returnValue(result)
+ Args:
+ app_id (str)
+ pushkey (str)
+ last_stream_ordering (int)
+ last_success (int)
+
+ Returns:
+ Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+ """
+ updated = yield self.db.simple_update(
+ table="pushers",
+ keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+ updatevalues={
+ "last_stream_ordering": last_stream_ordering,
+ "last_success": last_success,
+ },
+ desc="update_pusher_last_stream_ordering_and_success",
+ )
+
+ return bool(updated)
+
+ @defer.inlineCallbacks
+ def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
+ yield self.db.simple_update(
+ table="pushers",
+ keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+ updatevalues={"failing_since": failing_since},
+ desc="update_pusher_failing_since",
+ )
+
+ @defer.inlineCallbacks
+ def get_throttle_params_by_room(self, pusher_id):
+ res = yield self.db.simple_select_list(
+ "pusher_throttle",
+ {"pusher": pusher_id},
+ ["room_id", "last_sent_ts", "throttle_ms"],
+ desc="get_throttle_params_by_room",
+ )
+
+ params_by_room = {}
+ for row in res:
+ params_by_room[row["room_id"]] = {
+ "last_sent_ts": row["last_sent_ts"],
+ "throttle_ms": row["throttle_ms"],
+ }
+
+ return params_by_room
+
+ @defer.inlineCallbacks
+ def set_throttle_params(self, pusher_id, room_id, params):
+ # no need to lock because `pusher_throttle` has a primary key on
+ # (pusher, room_id) so simple_upsert will retry
+ yield self.db.simple_upsert(
+ "pusher_throttle",
+ {"pusher": pusher_id, "room_id": room_id},
+ params,
+ desc="set_throttle_params",
+ lock=False,
+ )
class PusherStore(PusherWorkerStore):
@@ -230,8 +298,8 @@ class PusherStore(PusherWorkerStore):
):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
- # (app_id, pushkey, user_name) so _simple_upsert will retry
- yield self._simple_upsert(
+ # (app_id, pushkey, user_name) so simple_upsert will retry
+ yield self.db.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -241,7 +309,7 @@ class PusherStore(PusherWorkerStore):
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
- "data": encode_canonical_json(data),
+ "data": bytearray(encode_canonical_json(data)),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
@@ -256,7 +324,7 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
# invalidate, since the user might not have had a pusher before
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
@@ -270,7 +338,7 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
- self._simple_delete_one_txn(
+ self.db.simple_delete_one_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -279,7 +347,7 @@ class PusherStore(PusherWorkerStore):
# it's possible for us to end up with duplicate rows for
# (app_id, pushkey, user_id) at different stream_ids, but that
# doesn't really matter.
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="deleted_pushers",
values={
@@ -291,68 +359,4 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
- yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
-
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering(
- self, app_id, pushkey, user_id, last_stream_ordering
- ):
- yield self._simple_update_one(
- "pushers",
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
- {'last_stream_ordering': last_stream_ordering},
- desc="update_pusher_last_stream_ordering",
- )
-
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering_and_success(
- self, app_id, pushkey, user_id, last_stream_ordering, last_success
- ):
- yield self._simple_update_one(
- "pushers",
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
- {
- 'last_stream_ordering': last_stream_ordering,
- 'last_success': last_success,
- },
- desc="update_pusher_last_stream_ordering_and_success",
- )
-
- @defer.inlineCallbacks
- def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
- yield self._simple_update_one(
- "pushers",
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
- {'failing_since': failing_since},
- desc="update_pusher_failing_since",
- )
-
- @defer.inlineCallbacks
- def get_throttle_params_by_room(self, pusher_id):
- res = yield self._simple_select_list(
- "pusher_throttle",
- {"pusher": pusher_id},
- ["room_id", "last_sent_ts", "throttle_ms"],
- desc="get_throttle_params_by_room",
- )
-
- params_by_room = {}
- for row in res:
- params_by_room[row["room_id"]] = {
- "last_sent_ts": row["last_sent_ts"],
- "throttle_ms": row["throttle_ms"],
- }
-
- defer.returnValue(params_by_room)
-
- @defer.inlineCallbacks
- def set_throttle_params(self, pusher_id, room_id, params):
- # no need to lock because `pusher_throttle` has a primary key on
- # (pusher, room_id) so _simple_upsert will retry
- yield self._simple_upsert(
- "pusher_throttle",
- {"pusher": pusher_id, "room_id": room_id},
- params,
- desc="set_throttle_params",
- lock=False,
- )
+ yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
diff --git a/synapse/storage/receipts.py b/synapse/storage/data_stores/main/receipts.py
index a1647e50a1..0d932a0672 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -21,12 +21,12 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from ._base import SQLBaseStore
-from .util.id_generators import StreamIdGenerator
-
logger = logging.getLogger(__name__)
@@ -39,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, db_conn, hs):
- super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -58,11 +58,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
- defer.returnValue(set(r['user_id'] for r in receipts))
+ return {r["user_id"] for r in receipts}
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
- return self._simple_select_list(
+ return self.db.simple_select_list(
table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
@@ -71,7 +71,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@@ -85,14 +85,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
- rows = yield self._simple_select_list(
+ rows = yield self.db.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
desc="get_receipts_for_user",
)
- defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
+ return {row["room_id"]: row["event_id"] for row in rows}
@defer.inlineCallbacks
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
@@ -109,17 +109,15 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
- rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
- defer.returnValue(
- {
- row[0]: {
- "event_id": row[1],
- "topological_ordering": row[2],
- "stream_ordering": row[3],
- }
- for row in rows
+ rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f)
+ return {
+ row[0]: {
+ "event_id": row[1],
+ "topological_ordering": row[2],
+ "stream_ordering": row[3],
}
- )
+ for row in rows
+ }
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
@@ -147,7 +145,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
room_ids, to_key, from_key=from_key
)
- defer.returnValue([ev for res in results.values() for ev in res])
+ return [ev for res in results.values() for ev in res]
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
@@ -190,14 +188,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (room_id, to_key))
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
return rows
- rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
+ rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
- defer.returnValue([])
+ return []
content = {}
for row in rows:
@@ -205,9 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
row["user_id"]
] = json.loads(row["data"])
- defer.returnValue(
- [{"type": "m.receipt", "room_id": room_id, "content": content}]
- )
+ return [{"type": "m.receipt", "room_id": room_id, "content": content}]
@cachedList(
cached_method_name="_get_linearized_receipts_for_room",
@@ -217,32 +213,36 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
- defer.returnValue({})
+ return {}
def f(txn):
if from_key:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
- ) % (",".join(["?"] * len(room_ids)))
- args = list(room_ids)
- args.extend([from_key, to_key])
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id > ? AND stream_id <= ? AND
+ """
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
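+                # "clause" matches room_id against the given ids (an IN (...) list, or an
+                # = ANY(?) form depending on the database engine) and is appended after the
+                # trailing AND in the SQL above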
- txn.execute(sql, args)
+ txn.execute(sql + clause, [from_key, to_key] + list(args))
else:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id IN (%s) AND stream_id <= ?"
- ) % (",".join(["?"] * len(room_ids)))
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id <= ? AND
+ """
- args = list(room_ids)
- args.append(to_key)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
- txn.execute(sql, args)
+ txn.execute(sql + clause, [to_key] + list(args))
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
+ txn_results = yield self.db.runInteraction(
+ "_get_linearized_receipts_for_rooms", f
+ )
results = {}
for row in txn_results:
@@ -264,7 +264,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
room_id: [results[room_id]] if room_id in results else []
for room_id in room_ids
}
- defer.returnValue(results)
+ return results
def get_all_updated_receipts(self, last_id, current_id, limit=None):
if last_id == current_id:
@@ -283,9 +283,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
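+            # materialise the rows into a list so they are all read while the transaction
+            # is still open, rather than returning a lazy generator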
- return (r[0:5] + (json.loads(r[5]),) for r in txn)
+ return [r[0:5] + (json.loads(r[5]),) for r in txn]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
@@ -316,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
class ReceiptsStore(ReceiptsWorkerStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
- super(ReceiptsStore, self).__init__(db_conn, hs)
+ super(ReceiptsStore, self).__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
@@ -338,7 +338,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
- res = self._simple_select_one_txn(
+ res = self.db.simple_select_one_txn(
txn,
table="events",
retcols=["stream_ordering", "received_ts"],
@@ -391,7 +391,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type),
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="receipts_linearized",
keyvalues={
@@ -401,7 +401,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
},
)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="receipts_linearized",
values={
@@ -437,26 +437,32 @@ class ReceiptsStore(ReceiptsWorkerStore):
        # we need to map points in graph form -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn):
- query = (
- "SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
- " SELECT max(stream_ordering) WHERE event_id IN (%s)"
- ")"
- ) % (",".join(["?"] * len(event_ids)))
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "event_id", event_ids
+ )
+
+ sql = """
+ SELECT event_id WHERE room_id = ? AND stream_ordering IN (
+ SELECT max(stream_ordering) WHERE %s
+ )
+ """ % (
+ clause,
+ )
- txn.execute(query, [room_id] + event_ids)
+ txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
- linearized_event_id = yield self.runInteraction(
+ linearized_event_id = yield self.db.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
- event_ts = yield self.runInteraction(
+ event_ts = yield self.db.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -468,7 +474,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
)
if event_ts is None:
- defer.returnValue(None)
+ return None
now = self._clock.time_msec()
logger.debug(
@@ -482,10 +488,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
max_persisted_id = self._receipts_id_gen.get_current_token()
- defer.returnValue((stream_id, max_persisted_id))
+ return stream_id, max_persisted_id
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
- return self.runInteraction(
+ return self.db.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
@@ -511,7 +517,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="receipts_graph",
keyvalues={
@@ -520,7 +526,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
},
)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="receipts_graph",
values={
diff --git a/synapse/storage/registration.py b/synapse/storage/data_stores/main/registration.py
index 028848cf89..035fe348b0 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -19,14 +19,15 @@ import logging
import re
from six import iterkeys
-from six.moves import range
from twisted.internet import defer
+from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, StoreError, ThreepidValidationError
-from synapse.storage import background_updates
+from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -36,25 +37,28 @@ logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(RegistrationWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
self.config = hs.config
self.clock = hs.get_clock()
@cached()
def get_user_by_id(self, user_id):
- return self._simple_select_one(
+ return self.db.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
"name",
"password_hash",
"is_guest",
+ "admin",
"consent_version",
"consent_server_notice_sent",
"appservice_id",
"creation_ts",
+ "user_type",
+ "deactivated",
],
allow_none=True,
desc="get_user_by_id",
@@ -74,12 +78,12 @@ class RegistrationWorkerStore(SQLBaseStore):
info = yield self.get_user_by_id(user_id)
if not info:
- defer.returnValue(False)
+ return False
now = self.clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
- defer.returnValue(is_trial)
+ return is_trial
@cached()
def get_user_by_access_token(self, token):
@@ -89,9 +93,10 @@ class RegistrationWorkerStore(SQLBaseStore):
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`.
+ including the keys `name`, `is_guest`, `device_id`, `token_id`,
+ `valid_until_ms`.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
@@ -106,18 +111,19 @@ class RegistrationWorkerStore(SQLBaseStore):
otherwise int representation of the timestamp (as a number of
milliseconds since epoch).
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
allow_none=True,
desc="get_expiration_ts_for_user",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
- def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
- renewal_token=None):
+ def set_account_validity_for_user(
+ self, user_id, expiration_ts, email_sent, renewal_token=None
+ ):
"""Updates the account validity properties of the given account, with the
given values.
@@ -131,8 +137,9 @@ class RegistrationWorkerStore(SQLBaseStore):
renewal_token (str): Renewal token the user can use to extend the validity
of their account. Defaults to no token.
"""
+
def set_account_validity_for_user_txn(txn):
- self._simple_update_txn(
+ self.db.simple_update_txn(
txn=txn,
table="account_validity",
keyvalues={"user_id": user_id},
@@ -143,12 +150,11 @@ class RegistrationWorkerStore(SQLBaseStore):
},
)
self._invalidate_cache_and_stream(
- txn, self.get_expiration_ts_for_user, (user_id,),
+ txn, self.get_expiration_ts_for_user, (user_id,)
)
- yield self.runInteraction(
- "set_account_validity_for_user",
- set_account_validity_for_user_txn,
+ yield self.db.runInteraction(
+ "set_account_validity_for_user", set_account_validity_for_user_txn
)
@defer.inlineCallbacks
@@ -158,6 +164,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[list[str]]: List of expired user IDs
"""
+
def get_expired_users_txn(txn, now_ms):
sql = """
SELECT user_id from account_validity
@@ -167,10 +174,8 @@ class RegistrationWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return [row[0] for row in rows]
- res = yield self.runInteraction(
- "get_expired_users",
- get_expired_users_txn,
- self.clock.time_msec(),
+ res = yield self.db.runInteraction(
+ "get_expired_users", get_expired_users_txn, self.clock.time_msec()
)
defer.returnValue(res)
@@ -186,7 +191,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Raises:
StoreError: The provided token is already set for another user.
"""
- yield self._simple_update_one(
+ yield self.db.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
@@ -203,14 +208,14 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The ID of the user to which the token belongs.
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
desc="get_user_from_renewal_token",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def get_renewal_token_for_user(self, user_id):
@@ -222,14 +227,14 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The renewal token associated with this user ID.
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
desc="get_renewal_token_for_user",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def get_users_expiring_soon(self):
@@ -240,6 +245,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
"""
+
def select_users_txn(txn, now_ms, renew_at):
sql = (
"SELECT user_id, expiration_ts_ms FROM account_validity"
@@ -247,15 +253,16 @@ class RegistrationWorkerStore(SQLBaseStore):
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- res = yield self.runInteraction(
+ res = yield self.db.runInteraction(
"get_users_expiring_soon",
select_users_txn,
- self.clock.time_msec(), self.config.account_validity.renew_at,
+ self.clock.time_msec(),
+ self.config.account_validity.renew_at,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def set_renewal_mail_status(self, user_id, email_sent):
@@ -267,7 +274,7 @@ class RegistrationWorkerStore(SQLBaseStore):
email_sent (bool): Flag which indicates whether a renewal email has been sent
to this user.
"""
- yield self._simple_update_one(
+ yield self.db.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
@@ -282,7 +289,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Args:
user_id (str): ID of the user to remove from the account validity table.
"""
- yield self._simple_delete_one(
+ yield self.db.simple_delete_one(
table="account_validity",
keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user",
@@ -290,7 +297,15 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def is_server_admin(self, user):
- res = yield self._simple_select_one_onecol(
+ """Determines if a user is an admin of this homeserver.
+
+ Args:
+ user (UserID): user ID of the user to test
+
+ Returns (bool):
+ true iff the user is a server admin, false otherwise.
+ """
+ res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
@@ -298,25 +313,59 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="is_server_admin",
)
- defer.returnValue(res if res else False)
+ return bool(res) if res else False
+
+ def set_server_admin(self, user, admin):
+ """Sets whether a user is an admin of this homeserver.
+
+ Args:
+ user (UserID): user ID of the user to test
+ admin (bool): true iff the user is to be a server admin,
+ false otherwise.
+ """
+
+ def set_server_admin_txn(txn):
+ self.db.simple_update_one_txn(
+ txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user.to_string(),)
+ )
+
+ return self.db.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
- " access_tokens.device_id"
+ " access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)
txn.execute(sql, (token,))
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]
return None
@cachedInlineCallbacks()
+ def is_real_user(self, user_id):
+        """Determines if the user is a real user, i.e. does not have a 'user_type'.
+
+ Args:
+ user_id (str): user id to test
+
+ Returns:
+ Deferred[bool]: True if user 'user_type' is null or empty string
+ """
+ res = yield self.db.runInteraction(
+ "is_real_user", self.is_real_user_txn, user_id
+ )
+ return res
+
+ @cachedInlineCallbacks()
def is_support_user(self, user_id):
"""Determines if the user is of type UserTypes.SUPPORT
@@ -326,13 +375,23 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if user is of type UserTypes.SUPPORT
"""
- res = yield self.runInteraction(
+ res = yield self.db.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
- defer.returnValue(res)
+ return res
+
+ def is_real_user_txn(self, txn, user_id):
+ res = self.db.simple_select_one_onecol_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="user_type",
+ allow_none=True,
+ )
+ return res is None
def is_support_user_txn(self, txn, user_id):
- res = self._simple_select_one_onecol_txn(
+ res = self.db.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -347,13 +406,31 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def f(txn):
- sql = (
- "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
- )
+ sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
txn.execute(sql, (user_id,))
return dict(txn)
- return self.runInteraction("get_users_by_id_case_insensitive", f)
+ return self.db.runInteraction("get_users_by_id_case_insensitive", f)
+
+ async def get_user_by_external_id(
+ self, auth_provider: str, external_id: str
+ ) -> str:
+ """Look up a user by their external auth id
+
+ Args:
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+
+ Returns:
+ str|None: the mxid of the user, or None if they are not known
+ """
+ return await self.db.simple_select_one_onecol(
+ table="user_external_ids",
+ keyvalues={"auth_provider": auth_provider, "external_id": external_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="get_user_by_external_id",
+ )
@defer.inlineCallbacks
def count_all_users(self):
@@ -361,13 +438,13 @@ class RegistrationWorkerStore(SQLBaseStore):
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
- ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("count_users", _count_users)
+ return ret
def count_daily_user_type(self):
"""
@@ -392,13 +469,13 @@ class RegistrationWorkerStore(SQLBaseStore):
WHERE creation_ts > ?
) AS t GROUP BY user_type
"""
- results = {'native': 0, 'guest': 0, 'bridged': 0}
+ results = {"native": 0, "guest": 0, "bridged": 0}
txn.execute(sql, (yesterday,))
for row in txn:
results[row[0]] = row[1]
return results
- return self.runInteraction("count_daily_user_type", _count_daily_user_type)
+ return self.db.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
@@ -409,63 +486,61 @@ class RegistrationWorkerStore(SQLBaseStore):
WHERE appservice_id IS NULL
"""
)
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("count_users", _count_users)
+ return ret
+
+ @defer.inlineCallbacks
+ def count_real_users(self):
+        """Counts all users registered on the homeserver who do not have a special user_type."""
+
+ def _count_users(txn):
+ txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
+ rows = self.db.cursor_to_dict(txn)
+ if rows:
+ return rows[0]["users"]
+ return 0
+
+ ret = yield self.db.runInteraction("count_real_users", _count_users)
+ return ret
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
- Generated user IDs are integers, and we aim for them to be as small as
- we can. Unfortunately, it's possible some of them are already taken by
- existing users, and there may be gaps in the already taken range. This
- function returns the start of the first allocatable gap. This is to
- avoid the case of ID 10000000 being pre-allocated, so us wasting the
- first (and shortest) many generated user IDs.
+ Generated user IDs are integers, so we find the largest integer user ID
+ already taken and return that plus one.
"""
def _find_next_generated_user_id(txn):
- txn.execute("SELECT name FROM users")
+ # We bound between '@0' and '@a' to avoid pulling the entire table
+ # out.
+ txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
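+            # ASCII digits sort between '0' and 'a', so every purely numeric localpart is
+            # included in this range; the regex below still does the exact matching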
regex = re.compile(r"^@(\d+):")
- found = set()
+ max_found = 0
for (user_id,) in txn:
match = regex.search(user_id)
if match:
- found.add(int(match.group(1)))
- for i in range(len(found) + 1):
- if i not in found:
- return i
+ max_found = max(int(match.group(1)), max_found)
+
+ return max_found + 1
- defer.returnValue(
+ return (
(
- yield self.runInteraction(
+ yield self.db.runInteraction(
"find_next_generated_user_id", _find_next_generated_user_id
)
)
)
@defer.inlineCallbacks
- def get_3pid_guest_access_token(self, medium, address):
- ret = yield self._simple_select_one(
- "threepid_guest_access_tokens",
- {"medium": medium, "address": address},
- ["guest_access_token"],
- True,
- 'get_3pid_guest_access_token',
- )
- if ret:
- defer.returnValue(ret["guest_access_token"])
- defer.returnValue(None)
-
- @defer.inlineCallbacks
- def get_user_id_by_threepid(self, medium, address, require_verified=False):
+ def get_user_id_by_threepid(self, medium, address):
"""Returns user id from threepid
Args:
@@ -475,10 +550,10 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
- user_id = yield self.runInteraction(
+ user_id = yield self.db.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
- defer.returnValue(user_id)
+ return user_id
def get_user_id_by_threepid_txn(self, txn, medium, address):
"""Returns user id from threepid
@@ -491,20 +566,20 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: user id or None if no user id/threepid mapping exists
"""
- ret = self._simple_select_one_txn(
+ ret = self.db.simple_select_one_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
- ['user_id'],
+ ["user_id"],
True,
)
if ret:
- return ret['user_id']
+ return ret["user_id"]
return None
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self._simple_upsert(
+ yield self.db.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
@@ -512,18 +587,31 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
- ret = yield self._simple_select_list(
+ ret = yield self.db.simple_select_list(
"user_threepids",
{"user_id": user_id},
- ['medium', 'address', 'validated_at', 'added_at'],
- 'user_get_threepids',
+ ["medium", "address", "validated_at", "added_at"],
+ "user_get_threepids",
)
- defer.returnValue(ret)
+ return ret
def user_delete_threepid(self, user_id, medium, address):
- return self._simple_delete(
+ return self.db.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
+ desc="user_delete_threepid",
+ )
+
+ def user_delete_threepids(self, user_id: str):
+        """Delete all threepids this user has bound
+
+ Args:
+ user_id: The user id to delete all threepids of
+
+ """
+ return self.db.simple_delete(
+ "user_threepids",
+ keyvalues={"user_id": user_id},
desc="user_delete_threepids",
)
@@ -543,7 +631,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
        # We need to use an upsert, in case the user had already bound the
# threepid
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -556,6 +644,26 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="add_user_bound_threepid",
)
+ def user_get_bound_threepids(self, user_id):
+ """Get the threepids that a user has bound to an identity server through the homeserver
+        The homeserver remembers where binds to an identity server occurred; this
+        method retrieves those threepids.
+
+ Args:
+ user_id (str): The ID of the user to retrieve threepids for
+
+ Returns:
+ Deferred[list[dict]]: List of dictionaries containing the following:
+                medium (str): The medium of the threepid (e.g. "email")
+                address (str): The address of the threepid (e.g. "bob@example.com")
+ """
+ return self.db.simple_select_list(
+ table="user_threepid_id_server",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address"],
+ desc="user_get_bound_threepids",
+ )
+
def remove_user_bound_threepid(self, user_id, medium, address, id_server):
"""The server proxied an unbind request to the given identity server on
behalf of the given user, so we remove the mapping of threepid to
@@ -570,7 +678,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred
"""
- return self._simple_delete(
+ return self.db.simple_delete(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -593,69 +701,170 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[list[str]]: Resolves to a list of identity servers
"""
- return self._simple_select_onecol(
+ return self.db.simple_select_onecol(
table="user_threepid_id_server",
- keyvalues={
- "user_id": user_id,
- "medium": medium,
- "address": address,
- },
+ keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
desc="get_id_servers_user_bound",
)
+ @cachedInlineCallbacks()
+ def get_user_deactivated_status(self, user_id):
+ """Retrieve the value for the `deactivated` property for the provided user.
-class RegistrationStore(
- RegistrationWorkerStore, background_updates.BackgroundUpdateStore
-):
- def __init__(self, db_conn, hs):
- super(RegistrationStore, self).__init__(db_conn, hs)
+ Args:
+ user_id (str): The ID of the user to retrieve the status for.
+
+ Returns:
+ defer.Deferred(bool): The requested value.
+ """
+
+ res = yield self.db.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="deactivated",
+ desc="get_user_deactivated_status",
+ )
+
+ # Convert the integer into a boolean.
+ return res == 1
+
+ def get_threepid_validation_session(
+ self, medium, client_secret, address=None, sid=None, validated=True
+ ):
+ """Gets a session_id and last_send_attempt (if available) for a
+ combination of validation metadata
+
+ Args:
+ medium (str|None): The medium of the 3PID
+ address (str|None): The address of the 3PID
+ sid (str|None): The ID of the validation session
+ client_secret (str): A unique string provided by the client to help identify this
+ validation attempt
+ validated (bool|None): Whether sessions should be filtered by
+ whether they have been validated already or not. None to
+ perform no filtering
+
+ Returns:
+ Deferred[dict|None]: A dict containing the following:
+ * address - address of the 3pid
+ * medium - medium of the 3pid
+ * client_secret - a secret provided by the client for this validation session
+ * session_id - ID of the validation session
+ * send_attempt - a number serving to dedupe send attempts for this session
+ * validated_at - timestamp of when this session was validated if so
+
+ Otherwise None if a validation session is not found
+ """
+ if not client_secret:
+ raise SynapseError(
+ 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM
+ )
+
+ keyvalues = {"client_secret": client_secret}
+ if medium:
+ keyvalues["medium"] = medium
+ if address:
+ keyvalues["address"] = address
+ if sid:
+ keyvalues["session_id"] = sid
+
+ assert address or sid
+
+ def get_threepid_validation_session_txn(txn):
+ sql = """
+ SELECT address, session_id, medium, client_secret,
+ last_send_attempt, validated_at
+ FROM threepid_validation_session WHERE %s
+ """ % (
+ " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
+ )
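+            # the interpolated column names come from the fixed keyvalues built above, so
+            # this string formatting does not allow SQL injection; the values themselves
+            # are passed as query parameters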
+
+ if validated is not None:
+ sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
+
+ sql += " LIMIT 1"
+
+ txn.execute(sql, list(keyvalues.values()))
+ rows = self.db.cursor_to_dict(txn)
+ if not rows:
+ return None
+
+ return rows[0]
+
+ return self.db.runInteraction(
+ "get_threepid_validation_session", get_threepid_validation_session_txn
+ )
+
+ def delete_threepid_session(self, session_id):
+ """Removes a threepid validation session from the database. This can
+ be done after validation has been performed and whatever action was
+ waiting on it has been carried out
+
+ Args:
+ session_id (str): The ID of the session to delete
+ """
+
+ def delete_threepid_session_txn(txn):
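+            # delete any validation tokens for the session first, then the session row itself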
+ self.db.simple_delete_txn(
+ txn,
+ table="threepid_validation_token",
+ keyvalues={"session_id": session_id},
+ )
+ self.db.simple_delete_txn(
+ txn,
+ table="threepid_validation_session",
+ keyvalues={"session_id": session_id},
+ )
+
+ return self.db.runInteraction(
+ "delete_threepid_session", delete_threepid_session_txn
+ )
+
+
+class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.clock = hs.get_clock()
+ self.config = hs.config
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"access_tokens_device_index",
index_name="access_tokens_device_id",
table="access_tokens",
columns=["user_id", "device_id"],
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"users_creation_ts",
index_name="users_creation_ts",
table="users",
columns=["creation_ts"],
)
- self._account_validity = hs.config.account_validity
-
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
- self.register_noop_background_update("refresh_tokens_device_index")
+ self.db.updates.register_noop_background_update("refresh_tokens_device_index")
- self.register_background_update_handler(
- "user_threepids_grandfather", self._bg_user_threepids_grandfather,
+ self.db.updates.register_background_update_handler(
+ "user_threepids_grandfather", self._bg_user_threepids_grandfather
)
- self.register_background_update_handler(
- "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag,
- )
-
- # Create a background job for culling expired 3PID validity tokens
- hs.get_clock().looping_call(
- self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS,
+ self.db.updates.register_background_update_handler(
+ "users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
@defer.inlineCallbacks
- def _backgroud_update_set_deactivated_flag(self, progress, batch_size):
+ def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
last_user = progress.get("user_id", "")
- def _backgroud_update_set_deactivated_flag_txn(txn):
+ def _background_update_set_deactivated_flag_txn(txn):
txn.execute(
"""
SELECT
@@ -676,10 +885,10 @@ class RegistrationStore(
(last_user, batch_size),
)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
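+                # nothing more to process: report that the update has finished and that no
+                # rows were handled in this batch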
- return True
+ return True, 0
rows_processed_nb = 0
@@ -690,49 +899,111 @@ class RegistrationStore(
logger.info("Marked %d rows as deactivated", rows_processed_nb)
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
)
if batch_size > len(rows):
- return True
+ return True, len(rows)
else:
- return False
+ return False, len(rows)
- end = yield self.runInteraction(
- "users_set_deactivated_flag",
- _backgroud_update_set_deactivated_flag_txn,
+ end, nb_processed = yield self.db.runInteraction(
+ "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
if end:
- yield self._end_background_update("users_set_deactivated_flag")
+ yield self.db.updates._end_background_update("users_set_deactivated_flag")
- defer.returnValue(batch_size)
+ return nb_processed
@defer.inlineCallbacks
- def add_access_token_to_user(self, user_id, token, device_id=None):
+ def _bg_user_threepids_grandfather(self, progress, batch_size):
+ """We now track which identity servers a user binds their 3PID to, so
+ we need to handle the case of existing bindings where we didn't track
+ this.
+
+ We do this by grandfathering in existing user threepids assuming that
+ they used one of the server configured trusted identity servers.
+ """
+ id_servers = set(self.config.trusted_third_party_id_servers)
+
+ def _bg_user_threepids_grandfather_txn(txn):
+ sql = """
+ INSERT INTO user_threepid_id_server
+ (user_id, medium, address, id_server)
+ SELECT user_id, medium, address, ?
+ FROM user_threepids
+ """
+
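+            # one INSERT ... SELECT per configured trusted identity server, recording every
+            # existing threepid as having been bound via that server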
+ txn.executemany(sql, [(id_server,) for id_server in id_servers])
+
+ if id_servers:
+ yield self.db.runInteraction(
+ "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
+ )
+
+ yield self.db.updates._end_background_update("user_threepids_grandfather")
+
+ return 1
+
+
+class RegistrationStore(RegistrationBackgroundUpdateStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RegistrationStore, self).__init__(database, db_conn, hs)
+
+ self._account_validity = hs.config.account_validity
+
+ if self._account_validity.enabled:
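+            # backfill missing expiration dates shortly after startup; wrapping the job in
+            # run_as_background_process gives it its own logcontext to report against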
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "account_validity_set_expiration_dates",
+ self._set_expiration_date_when_missing,
+ )
+
+ # Create a background job for culling expired 3PID validity tokens
+ def start_cull():
+ # run as a background process to make sure that the database transactions
+ # have a logcontext to report to
+ return run_as_background_process(
+ "cull_expired_threepid_validation_tokens",
+ self.cull_expired_threepid_validation_tokens,
+ )
+
+ hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
+
+ @defer.inlineCallbacks
+ def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
device_id (str): ID of the device to associate with the access
- token
+ token
+ valid_until_ms (int|None): when the token is valid until. None for
+ no expiry.
Raises:
StoreError if there was a problem adding this.
"""
next_id = self._access_tokens_id_gen.get_next()
- yield self._simple_insert(
+ yield self.db.simple_insert(
"access_tokens",
- {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
+ {
+ "id": next_id,
+ "user_id": user_id,
+ "token": token,
+ "device_id": device_id,
+ "valid_until_ms": valid_until_ms,
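+                # valid_until_ms may be None, in which case the token never expires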
+ },
desc="add_access_token_to_user",
)
- def register(
+ def register_user(
self,
user_id,
- token=None,
password_hash=None,
was_guest=False,
make_guest=False,
@@ -745,9 +1016,6 @@ class RegistrationStore(
Args:
user_id (str): The desired user ID to register.
- token (str): The desired access token to use for this user. If this
- is not None, the given access token is associated with the user
- id.
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
@@ -763,11 +1031,10 @@ class RegistrationStore(
Raises:
StoreError if the user_id could not be registered.
"""
- return self.runInteraction(
- "register",
- self._register,
+ return self.db.runInteraction(
+ "register_user",
+ self._register_user,
user_id,
- token,
password_hash,
was_guest,
make_guest,
@@ -777,11 +1044,10 @@ class RegistrationStore(
user_type,
)
- def _register(
+ def _register_user(
self,
txn,
user_id,
- token,
password_hash,
was_guest,
make_guest,
@@ -794,14 +1060,12 @@ class RegistrationStore(
now = int(self.clock.time())
- next_id = self._access_tokens_id_gen.get_next()
-
try:
if was_guest:
# Ensure that the guest user actually exists
# ``allow_none=False`` makes this raise an exception
# if the row isn't in the database.
- self._simple_select_one_txn(
+ self.db.simple_select_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -809,7 +1073,7 @@ class RegistrationStore(
allow_none=False,
)
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -823,7 +1087,7 @@ class RegistrationStore(
},
)
else:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
"users",
values={
@@ -843,14 +1107,6 @@ class RegistrationStore(
if self._account_validity.enabled:
self.set_expiration_date_for_user_txn(txn, user_id)
- if token:
- # it's possible for this to get a conflict, but only for a single user
- # since tokens are namespaced based on their user ID
- txn.execute(
- "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
- (next_id, user_id, token),
- )
-
if create_profile_with_displayname:
# set a default displayname serverside to avoid ugly race
# between auto-joins and clients trying to set displaynames
@@ -862,9 +1118,40 @@ class RegistrationStore(
(user_id_obj.localpart, create_profile_with_displayname),
)
+ if self.hs.config.stats_enabled:
+ # we create a new completed user statistics row
+
+ # we don't strictly need current_token since this user really can't
+ # have any state deltas before now (as it is a new user), but still,
+ # we include it for completeness.
+ current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
+ self._update_stats_delta_txn(
+ txn, now, "user", user_id, {}, complete_with_stream_id=current_token
+ )
+
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
+ def record_user_external_id(
+ self, auth_provider: str, external_id: str, user_id: str
+ ) -> Deferred:
+ """Record a mapping from an external user id to a mxid
+
+ Args:
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+ user_id: complete mxid that it is mapped to
+ """
+ return self.db.simple_insert(
+ table="user_external_ids",
+ values={
+ "auth_provider": auth_provider,
+ "external_id": external_id,
+ "user_id": user_id,
+ },
+ desc="record_user_external_id",
+ )
+
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
@@ -873,12 +1160,14 @@ class RegistrationStore(
"""
def user_set_password_hash_txn(txn):
- self._simple_update_one_txn(
- txn, 'users', {'name': user_id}, {'password_hash': password_hash}
+ self.db.simple_update_one_txn(
+ txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
+ return self.db.runInteraction(
+ "user_set_password_hash", user_set_password_hash_txn
+ )
def user_set_consent_version(self, user_id, consent_version):
"""Updates the user table to record privacy policy consent
@@ -893,15 +1182,15 @@ class RegistrationStore(
"""
def f(txn):
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
- table='users',
- keyvalues={'name': user_id},
- updatevalues={'consent_version': consent_version},
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"consent_version": consent_version},
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.runInteraction("user_set_consent_version", f)
+ return self.db.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
"""Updates the user table to record that we have sent the user a server
@@ -917,15 +1206,15 @@ class RegistrationStore(
"""
def f(txn):
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
- table='users',
- keyvalues={'name': user_id},
- updatevalues={'consent_server_notice_sent': consent_version},
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"consent_server_notice_sent": consent_version},
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.runInteraction("user_set_consent_server_notice_sent", f)
+ return self.db.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
"""
@@ -971,11 +1260,11 @@ class RegistrationStore(
return tokens_and_devices
- return self.runInteraction("user_delete_access_tokens", f)
+ return self.db.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
def f(txn):
- self._simple_delete_one_txn(
+ self.db.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@@ -983,11 +1272,11 @@ class RegistrationStore(
txn, self.get_user_by_access_token, (access_token,)
)
- return self.runInteraction("delete_access_token", f)
+ return self.db.runInteraction("delete_access_token", f)
@cachedInlineCallbacks()
def is_guest(self, user_id):
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
@@ -995,48 +1284,14 @@ class RegistrationStore(
desc="is_guest",
)
- defer.returnValue(res if res else False)
-
- @defer.inlineCallbacks
- def save_or_get_3pid_guest_access_token(
- self, medium, address, access_token, inviter_user_id
- ):
- """
- Gets the 3pid's guest access token if exists, else saves access_token.
-
- Args:
- medium (str): Medium of the 3pid. Must be "email".
- address (str): 3pid address.
- access_token (str): The access token to persist if none is
- already persisted.
- inviter_user_id (str): User ID of the inviter.
-
- Returns:
- deferred str: Whichever access token is persisted at the end
- of this function call.
- """
-
- def insert(txn):
- txn.execute(
- "INSERT INTO threepid_guest_access_tokens "
- "(medium, address, guest_access_token, first_inviter) "
- "VALUES (?, ?, ?, ?)",
- (medium, address, access_token, inviter_user_id),
- )
-
- try:
- yield self.runInteraction("save_3pid_guest_access_token", insert)
- defer.returnValue(access_token)
- except self.database_engine.module.IntegrityError:
- ret = yield self.get_3pid_guest_access_token(medium, address)
- defer.returnValue(ret)
+ return res if res else False
def add_user_pending_deactivation(self, user_id):
"""
Adds a user to the table of users who need to be parted from all the rooms they're
in
"""
- return self._simple_insert(
+ return self.db.simple_insert(
"users_pending_deactivation",
values={"user_id": user_id},
desc="add_user_pending_deactivation",
@@ -1049,7 +1304,7 @@ class RegistrationStore(
"""
# XXX: This should be simple_delete_one but we failed to put a unique index on
# the table, so somehow duplicate entries have ended up in it.
- return self._simple_delete(
+ return self.db.simple_delete(
"users_pending_deactivation",
keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
@@ -1060,7 +1315,7 @@ class RegistrationStore(
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@@ -1068,104 +1323,7 @@ class RegistrationStore(
desc="get_users_pending_deactivation",
)
- @defer.inlineCallbacks
- def _bg_user_threepids_grandfather(self, progress, batch_size):
- """We now track which identity servers a user binds their 3PID to, so
- we need to handle the case of existing bindings where we didn't track
- this.
-
- We do this by grandfathering in existing user threepids assuming that
- they used one of the server configured trusted identity servers.
- """
- id_servers = set(self.config.trusted_third_party_id_servers)
-
- def _bg_user_threepids_grandfather_txn(txn):
- sql = """
- INSERT INTO user_threepid_id_server
- (user_id, medium, address, id_server)
- SELECT user_id, medium, address, ?
- FROM user_threepids
- """
-
- txn.executemany(sql, [(id_server,) for id_server in id_servers])
-
- if id_servers:
- yield self.runInteraction(
- "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn,
- )
-
- yield self._end_background_update("user_threepids_grandfather")
-
- defer.returnValue(1)
-
- def get_threepid_validation_session(
- self,
- medium,
- client_secret,
- address=None,
- sid=None,
- validated=True,
- ):
- """Gets a session_id and last_send_attempt (if available) for a
- client_secret/medium/(address|session_id) combo
-
- Args:
- medium (str|None): The medium of the 3PID
- address (str|None): The address of the 3PID
- sid (str|None): The ID of the validation session
- client_secret (str|None): A unique string provided by the client to
- help identify this validation attempt
- validated (bool|None): Whether sessions should be filtered by
- whether they have been validated already or not. None to
- perform no filtering
-
- Returns:
- deferred {str, int}|None: A dict containing the
- latest session_id and send_attempt count for this 3PID.
- Otherwise None if there hasn't been a previous attempt
- """
- keyvalues = {
- "medium": medium,
- "client_secret": client_secret,
- }
- if address:
- keyvalues["address"] = address
- if sid:
- keyvalues["session_id"] = sid
-
- assert(address or sid)
-
- def get_threepid_validation_session_txn(txn):
- sql = """
- SELECT address, session_id, medium, client_secret,
- last_send_attempt, validated_at
- FROM threepid_validation_session WHERE %s
- """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),)
-
- if validated is not None:
- sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
-
- sql += " LIMIT 1"
-
- txn.execute(sql, list(keyvalues.values()))
- rows = self.cursor_to_dict(txn)
- if not rows:
- return None
-
- return rows[0]
-
- return self.runInteraction(
- "get_threepid_validation_session",
- get_threepid_validation_session_txn,
- )
-
- def validate_threepid_session(
- self,
- session_id,
- client_secret,
- token,
- current_ts,
- ):
+ def validate_threepid_session(self, session_id, client_secret, token, current_ts):
"""Attempt to validate a threepid session using a token
Args:
@@ -1176,13 +1334,18 @@ class RegistrationStore(
current_ts (int): The current unix time in milliseconds. Used for
checking token expiry status
+ Raises:
+ ThreepidValidationError: if a matching validation token was not found or has
+ expired
+
Returns:
deferred str|None: A str representing a link to redirect the user
to if there is one.
"""
+
# Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn):
- row = self._simple_select_one_txn(
+ row = self.db.simple_select_one_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1197,10 +1360,10 @@ class RegistrationStore(
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id",
+ 400, "This client_secret does not match the provided session_id"
)
- row = self._simple_select_one_txn(
+ row = self.db.simple_select_one_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id, "token": token},
@@ -1210,7 +1373,7 @@ class RegistrationStore(
if not row:
raise ThreepidValidationError(
- 400, "Validation token not found or has expired",
+ 400, "Validation token not found or has expired"
)
expires = row["expires"]
next_link = row["next_link"]
@@ -1221,11 +1384,11 @@ class RegistrationStore(
if expires <= current_ts:
raise ThreepidValidationError(
- 400, "This token has expired. Please request a new one",
+ 400, "This token has expired. Please request a new one"
)
# Looks good. Validate the session
- self._simple_update_txn(
+ self.db.simple_update_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1235,9 +1398,8 @@ class RegistrationStore(
return next_link
# Return next_link if it exists
- return self.runInteraction(
- "validate_threepid_session_txn",
- validate_threepid_session_txn,
+ return self.db.runInteraction(
+ "validate_threepid_session_txn", validate_threepid_session_txn
)
def upsert_threepid_validation_session(
@@ -1269,7 +1431,7 @@ class RegistrationStore(
if validated_at:
insertion_values["validated_at"] = validated_at
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="threepid_validation_session",
keyvalues={"session_id": session_id},
values={"last_send_attempt": send_attempt},
@@ -1304,9 +1466,10 @@ class RegistrationStore(
token_expires (int): The timestamp for which after the token
will no longer be valid
"""
+
def start_or_continue_validation_session_txn(txn):
# Create or update a validation session
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1319,7 +1482,7 @@ class RegistrationStore(
)
# Create a new validation token with this session ID
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="threepid_validation_token",
values={
@@ -1330,13 +1493,14 @@ class RegistrationStore(
},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
def cull_expired_threepid_validation_tokens(self):
"""Remove threepid validation tokens with expiry dates that have passed"""
+
def cull_expired_threepid_validation_tokens_txn(txn, ts):
sql = """
DELETE FROM threepid_validation_token WHERE
@@ -1344,80 +1508,91 @@ class RegistrationStore(
"""
return txn.execute(sql, (ts,))
- return self.runInteraction(
+ return self.db.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
)
- def delete_threepid_session(self, session_id):
- """Removes a threepid validation session from the database. This can
- be done after validation has been performed and whatever action was
- waiting on it has been carried out
+ @defer.inlineCallbacks
+ def set_user_deactivated_status(self, user_id, deactivated):
+ """Set the `deactivated` property for the provided user to the provided value.
Args:
- session_id (str): The ID of the session to delete
+ user_id (str): The ID of the user to set the status for.
+ deactivated (bool): The value to set for `deactivated`.
"""
- def delete_threepid_session_txn(txn):
- self._simple_delete_txn(
- txn,
- table="threepid_validation_token",
- keyvalues={"session_id": session_id},
- )
- self._simple_delete_txn(
- txn,
- table="threepid_validation_session",
- keyvalues={"session_id": session_id},
- )
- return self.runInteraction(
- "delete_threepid_session",
- delete_threepid_session_txn,
+ yield self.db.runInteraction(
+ "set_user_deactivated_status",
+ self.set_user_deactivated_status_txn,
+ user_id,
+ deactivated,
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
- txn, self.get_user_deactivated_status, (user_id,),
+ txn, self.get_user_deactivated_status, (user_id,)
)
@defer.inlineCallbacks
- def set_user_deactivated_status(self, user_id, deactivated):
- """Set the `deactivated` property for the provided user to the provided value.
-
- Args:
- user_id (str): The ID of the user to set the status for.
- deactivated (bool): The value to set for `deactivated`.
+ def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
"""
- yield self.runInteraction(
- "set_user_deactivated_status",
- self.set_user_deactivated_status_txn,
- user_id, deactivated,
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database, filtering out deactivated users.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+ )
+ txn.execute(sql, [])
+
+ res = self.db.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn, user["name"], use_delta=True
+ )
+
+ yield self.db.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
)
- @cachedInlineCallbacks()
- def get_user_deactivated_status(self, user_id):
- """Retrieve the value for the `deactivated` property for the provided user.
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
Args:
- user_id (str): The ID of the user to retrieve the status for.
-
- Returns:
- defer.Deferred(bool): The requested value.
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+ random value in the [now + period - d ; now + period] range, d being a
+ delta equal to 10% of the validity period.
"""
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
- res = yield self._simple_select_one_onecol(
- table="users",
- keyvalues={"name": user_id},
- retcol="deactivated",
- desc="get_user_deactivated_status",
- )
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
- # Convert the integer into a boolean.
- defer.returnValue(res == 1)
+ self.db.simple_upsert_txn(
+ txn,
+ "account_validity",
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+ )
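Editor's note (not part of the diff): the randomised expiration above simply pulls the expiry back by up to the configured startup-job delta, so that accounts handled in the same startup pass don't all expire at once. A minimal standalone sketch of that calculation, with stand-in values for the configured period and maximum delta:

    import random
    import time

    def compute_expiration_ts(period_ms, max_delta_ms, use_delta=False):
        """Expiry is now + period, optionally pulled back by a random delta."""
        now_ms = int(time.time() * 1000)
        expiration_ts = now_ms + period_ms
        if use_delta:
            # Pick a random point in [expiration_ts - max_delta_ms, expiration_ts).
            expiration_ts = random.randrange(
                expiration_ts - max_delta_ms, expiration_ts
            )
        return expiration_ts

    # Example: a 30-day validity period with a 3-day maximum startup delta.
    print(compute_expiration_ts(30 * 24 * 3600 * 1000, 3 * 24 * 3600 * 1000, True))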
diff --git a/synapse/storage/rejections.py b/synapse/storage/data_stores/main/rejections.py
index f4c1c2a457..1c07c7a425 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -15,14 +15,14 @@
import logging
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
def _store_rejections_txn(self, txn, event_id, reason):
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="rejections",
values={
@@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore):
)
def get_rejection_reason(self, event_id):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
new file mode 100644
index 0000000000..046c2b4845
--- /dev/null
+++ b/synapse/storage/data_stores/main/relations.py
@@ -0,0 +1,385 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+
+from synapse.api.constants import RelationTypes
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.stream import generate_pagination_where_clause
+from synapse.storage.relations import (
+ AggregationPaginationToken,
+ PaginationChunk,
+ RelationPaginationToken,
+)
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+class RelationsWorkerStore(SQLBaseStore):
+ @cached(tree=True)
+ def get_relations_for_event(
+ self,
+ event_id,
+ relation_type=None,
+ event_type=None,
+ aggregation_key=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of relations for an event, ordered by topological ordering.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ relation_type (str|None): Only fetch events with this relation
+ type, if given.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ aggregation_key (str|None): Only fetch events with this aggregation
+ key, if given.
+ limit (int): Only fetch the most recent `limit` events.
+ direction (str): Whether to fetch the most recent first (`"b"`) or
+ the oldest first (`"f"`).
+ from_token (RelationPaginationToken|None): Fetch rows from the given
+ token, or from the start if None.
+ to_token (RelationPaginationToken|None): Fetch rows up to the given
+ token, or up to the end if None.
+
+ Returns:
+ Deferred[PaginationChunk]: List of event IDs that match relations
+ requested. The rows are of the form `{"event_id": "..."}`.
+ """
+
+ where_clause = ["relates_to_id = ?"]
+ where_args = [event_id]
+
+ if relation_type is not None:
+ where_clause.append("relation_type = ?")
+ where_args.append(relation_type)
+
+ if event_type is not None:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ if aggregation_key:
+ where_clause.append("aggregation_key = ?")
+ where_args.append(aggregation_key)
+
+ pagination_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("topological_ordering", "stream_ordering"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if pagination_clause:
+ where_clause.append(pagination_clause)
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ sql = """
+ SELECT event_id, topological_ordering, stream_ordering
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE %s
+ ORDER BY topological_ordering %s, stream_ordering %s
+ LIMIT ?
+ """ % (
+ " AND ".join(where_clause),
+ order,
+ order,
+ )
+
+ def _get_recent_references_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ last_topo_id = None
+ last_stream_id = None
+ events = []
+ for row in txn:
+ events.append({"event_id": row[0]})
+ last_topo_id = row[1]
+ last_stream_id = row[2]
+
+ next_batch = None
+ if len(events) > limit and last_topo_id and last_stream_id:
+ next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.db.runInteraction(
+ "get_recent_references_for_event", _get_recent_references_for_event_txn
+ )
+
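Editor's note: the `limit + 1` fetch in `get_relations_for_event` is the usual "request one extra row to detect another page" pattern; the extra row also supplies the ordering values for the next-batch token. A toy, database-free sketch of the same bookkeeping (plain tuples stand in for `RelationPaginationToken`):

    def paginate(rows, limit):
        """rows: up to limit+1 (event_id, topological, stream) tuples in query order."""
        next_token = None
        if len(rows) > limit:
            # An extra row came back, so another page exists; the last row we
            # saw supplies the ordering values to resume from.
            last = rows[-1]
            next_token = (last[1], last[2])
        chunk = [{"event_id": r[0]} for r in rows[:limit]]
        return chunk, next_token

    rows = [("$a", 3, 30), ("$b", 2, 20), ("$c", 1, 10)]
    print(paginate(rows, 2))
    # ([{'event_id': '$a'}, {'event_id': '$b'}], (1, 10))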
+ @cached(tree=True)
+ def get_aggregation_groups_for_event(
+ self,
+ event_id,
+ event_type=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of annotations on the event, grouped by event type and
+ aggregation key, sorted by count.
+
+        This is used, e.g., to get which reactions have happened on an event
+        and how many of each there are.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ limit (int): Only fetch the `limit` groups.
+ direction (str): Whether to fetch the highest count first (`"b"`) or
+ the lowest count first (`"f"`).
+ from_token (AggregationPaginationToken|None): Fetch rows from the
+ given token, or from the start if None.
+ to_token (AggregationPaginationToken|None): Fetch rows up to the
+ given token, or up to the end if None.
+
+
+ Returns:
+ Deferred[PaginationChunk]: List of groups of annotations that
+ match. Each row is a dict with `type`, `key` and `count` fields.
+ """
+
+ where_clause = ["relates_to_id = ?", "relation_type = ?"]
+ where_args = [event_id, RelationTypes.ANNOTATION]
+
+ if event_type:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ having_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("COUNT(*)", "MAX(stream_ordering)"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ if having_clause:
+ having_clause = "HAVING " + having_clause
+ else:
+ having_clause = ""
+
+ sql = """
+ SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE {where_clause}
+ GROUP BY relation_type, type, aggregation_key
+ {having_clause}
+ ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+ LIMIT ?
+ """.format(
+ where_clause=" AND ".join(where_clause),
+ order=order,
+ having_clause=having_clause,
+ )
+
+ def _get_aggregation_groups_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ next_batch = None
+ events = []
+ for row in txn:
+ events.append({"type": row[0], "key": row[1], "count": row[2]})
+ next_batch = AggregationPaginationToken(row[2], row[3])
+
+ if len(events) <= limit:
+ next_batch = None
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.db.runInteraction(
+ "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+ )
+
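Editor's note: the GROUP BY in `get_aggregation_groups_for_event` is, in effect, counting annotations per (event type, aggregation key) pair. The same aggregation outside SQL, as a toy sketch (it counts rows rather than distinct senders):

    from collections import Counter

    annotations = [
        ("m.reaction", "👍"),
        ("m.reaction", "👍"),
        ("m.reaction", "🎉"),
    ]

    counts = Counter(annotations)
    groups = [{"type": t, "key": k, "count": c} for (t, k), c in counts.most_common()]
    print(groups)
    # [{'type': 'm.reaction', 'key': '👍', 'count': 2},
    #  {'type': 'm.reaction', 'key': '🎉', 'count': 1}]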
+ @cachedInlineCallbacks()
+ def get_applicable_edit(self, event_id):
+ """Get the most recent edit (if any) that has happened for the given
+ event.
+
+ Correctly handles checking whether edits were allowed to happen.
+
+ Args:
+ event_id (str): The original event ID
+
+ Returns:
+ Deferred[EventBase|None]: Returns the most recent edit, if any.
+ """
+
+ # We only allow edits for `m.room.message` events that have the same sender
+ # and event type. We can't assert these things during regular event auth so
+ # we have to do the checks post hoc.
+
+ # Fetches latest edit that has the same type and sender as the
+ # original, and is an `m.room.message`.
+ sql = """
+ SELECT edit.event_id FROM events AS edit
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND edit.type = 'm.room.message'
+ ORDER by edit.origin_server_ts DESC, edit.event_id DESC
+ LIMIT 1
+ """
+
+ def _get_applicable_edit_txn(txn):
+ txn.execute(sql, (event_id, RelationTypes.REPLACE))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ edit_id = yield self.db.runInteraction(
+ "get_applicable_edit", _get_applicable_edit_txn
+ )
+
+ if not edit_id:
+ return
+
+ edit_event = yield self.get_event(edit_id, allow_none=True)
+ return edit_event
+
+ def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+ """Check if a user has already annotated an event with the same key
+ (e.g. already liked an event).
+
+ Args:
+ parent_id (str): The event being annotated
+ event_type (str): The event type of the annotation
+ aggregation_key (str): The aggregation key of the annotation
+ sender (str): The sender of the annotation
+
+ Returns:
+ Deferred[bool]
+ """
+
+ sql = """
+ SELECT 1 FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND type = ?
+ AND sender = ?
+ AND aggregation_key = ?
+ LIMIT 1;
+ """
+
+ def _get_if_user_has_annotated_event(txn):
+ txn.execute(
+ sql,
+ (
+ parent_id,
+ RelationTypes.ANNOTATION,
+ event_type,
+ sender,
+ aggregation_key,
+ ),
+ )
+
+ return bool(txn.fetchone())
+
+ return self.db.runInteraction(
+ "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
+ )
+
+
+class RelationsStore(RelationsWorkerStore):
+ def _handle_event_relations(self, txn, event):
+ """Handles inserting relation data during peristence of events
+
+ Args:
+ txn
+ event (EventBase)
+ """
+ relation = event.content.get("m.relates_to")
+ if not relation:
+ # No relations
+ return
+
+ rel_type = relation.get("rel_type")
+ if rel_type not in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ RelationTypes.REPLACE,
+ ):
+ # Unknown relation type
+ return
+
+ parent_id = relation.get("event_id")
+ if not parent_id:
+ # Invalid relation
+ return
+
+ aggregation_key = relation.get("key")
+
+ self.db.simple_insert_txn(
+ txn,
+ table="event_relations",
+ values={
+ "event_id": event.event_id,
+ "relates_to_id": parent_id,
+ "relation_type": rel_type,
+ "aggregation_key": aggregation_key,
+ },
+ )
+
+ txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
+ txn.call_after(
+ self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+ )
+
+ if rel_type == RelationTypes.REPLACE:
+ txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
+
+ def _handle_redaction(self, txn, redacted_event_id):
+ """Handles receiving a redaction and checking whether we need to remove
+ any redacted relations from the database.
+
+ Args:
+ txn
+ redacted_event_id (str): The event that was redacted.
+ """
+
+ self.db.simple_delete_txn(
+ txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
+ )
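Editor's note: `_handle_event_relations` boils down to pulling `m.relates_to` out of the event content and rejecting anything without a known `rel_type` or a parent `event_id`. The same check as a standalone sketch over a plain dict (the literal strings mirror the `RelationTypes` constants):

    KNOWN_RELATION_TYPES = {"m.annotation", "m.reference", "m.replace"}

    def parse_relation(content):
        """Return (rel_type, parent_id, aggregation_key), or None if unusable."""
        relation = content.get("m.relates_to")
        if not relation:
            return None
        rel_type = relation.get("rel_type")
        if rel_type not in KNOWN_RELATION_TYPES:
            return None
        parent_id = relation.get("event_id")
        if not parent_id:
            return None
        return rel_type, parent_id, relation.get("key")

    print(parse_relation(
        {"m.relates_to": {"rel_type": "m.annotation", "event_id": "$parent", "key": "👍"}}
    ))
    # ('m.annotation', '$parent', '👍')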
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
new file mode 100644
index 0000000000..511316938d
--- /dev/null
+++ b/synapse/storage/data_stores/main/room.py
@@ -0,0 +1,1404 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import logging
+import re
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from six import integer_types
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import StoreError
+from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.search import SearchStore
+from synapse.storage.database import Database, LoggingTransaction
+from synapse.types import ThirdPartyInstanceID
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+OpsLevel = collections.namedtuple(
+ "OpsLevel", ("ban_level", "kick_level", "redact_level")
+)
+
+RatelimitOverride = collections.namedtuple(
+ "RatelimitOverride", ("messages_per_second", "burst_count")
+)
+
+
+class RoomSortOrder(Enum):
+ """
+ Enum to define the sorting method used when returning rooms with get_rooms_paginate
+
+ ALPHABETICAL = sort rooms alphabetically by name
+ SIZE = sort rooms by membership size, highest to lowest
+ """
+
+ ALPHABETICAL = "alphabetical"
+ SIZE = "size"
+
+
+class RoomWorkerStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomWorkerStore, self).__init__(database, db_conn, hs)
+
+ self.config = hs.config
+
+ def get_room(self, room_id):
+ """Retrieve a room.
+
+ Args:
+ room_id (str): The ID of the room to retrieve.
+ Returns:
+ A dict containing the room information, or None if the room is unknown.
+ """
+ return self.db.simple_select_one(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("room_id", "is_public", "creator"),
+ desc="get_room",
+ allow_none=True,
+ )
+
+ def get_public_room_ids(self):
+ return self.db.simple_select_onecol(
+ table="rooms",
+ keyvalues={"is_public": True},
+ retcol="room_id",
+ desc="get_public_room_ids",
+ )
+
+ def count_public_rooms(self, network_tuple, ignore_non_federatable):
+ """Counts the number of public rooms as tracked in the room_stats_current
+ and room_stats_state table.
+
+ Args:
+ network_tuple (ThirdPartyInstanceID|None)
+ ignore_non_federatable (bool): If true filters out non-federatable rooms
+ """
+
+ def _count_public_rooms_txn(txn):
+ query_args = []
+
+ if network_tuple:
+ if network_tuple.appservice_id:
+ published_sql = """
+ SELECT room_id from appservice_room_list
+ WHERE appservice_id = ? AND network_id = ?
+ """
+ query_args.append(network_tuple.appservice_id)
+ query_args.append(network_tuple.network_id)
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ """
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ UNION SELECT room_id from appservice_room_list
+ """
+
+ sql = """
+ SELECT
+ COALESCE(COUNT(*), 0)
+ FROM (
+ %(published_sql)s
+ ) published
+ INNER JOIN room_stats_state USING (room_id)
+ INNER JOIN room_stats_current USING (room_id)
+ WHERE
+ (
+ join_rules = 'public' OR history_visibility = 'world_readable'
+ )
+ AND joined_members > 0
+ """ % {
+ "published_sql": published_sql
+ }
+
+ txn.execute(sql, query_args)
+ return txn.fetchone()[0]
+
+ return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
+
+ @defer.inlineCallbacks
+ def get_largest_public_rooms(
+ self,
+ network_tuple: Optional[ThirdPartyInstanceID],
+ search_filter: Optional[dict],
+ limit: Optional[int],
+ bounds: Optional[Tuple[int, str]],
+ forwards: bool,
+ ignore_non_federatable: bool = False,
+ ):
+ """Gets the largest public rooms (where largest is in terms of joined
+ members, as tracked in the statistics table).
+
+ Args:
+ network_tuple
+ search_filter
+            limit: Maximum number of rows to return; unlimited otherwise.
+            bounds: An upper or lower bound to apply to the result set if given,
+                consisting of a joined member count and a room_id (these are
+                excluded from the result set).
+            forwards: True if paginating forwards, otherwise backwards.
+ ignore_non_federatable: If true filters out non-federatable rooms.
+
+ Returns:
+ Rooms in order: biggest number of joined users first.
+ We then arbitrarily use the room_id as a tie breaker.
+
+ """
+
+ where_clauses = []
+ query_args = []
+
+ if network_tuple:
+ if network_tuple.appservice_id:
+ published_sql = """
+ SELECT room_id from appservice_room_list
+ WHERE appservice_id = ? AND network_id = ?
+ """
+ query_args.append(network_tuple.appservice_id)
+ query_args.append(network_tuple.network_id)
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ """
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ UNION SELECT room_id from appservice_room_list
+ """
+
+ # Work out the bounds if we're given them, these bounds look slightly
+ # odd, but are designed to help query planner use indices by pulling
+ # out a common bound.
+ if bounds:
+ last_joined_members, last_room_id = bounds
+ if forwards:
+ where_clauses.append(
+ """
+ joined_members <= ? AND (
+ joined_members < ? OR room_id < ?
+ )
+ """
+ )
+ else:
+ where_clauses.append(
+ """
+ joined_members >= ? AND (
+ joined_members > ? OR room_id > ?
+ )
+ """
+ )
+
+ query_args += [last_joined_members, last_joined_members, last_room_id]
+
+ if ignore_non_federatable:
+ where_clauses.append("is_federatable")
+
+ if search_filter and search_filter.get("generic_search_term", None):
+ search_term = "%" + search_filter["generic_search_term"] + "%"
+
+ where_clauses.append(
+ """
+ (
+ LOWER(name) LIKE ?
+ OR LOWER(topic) LIKE ?
+ OR LOWER(canonical_alias) LIKE ?
+ )
+ """
+ )
+ query_args += [
+ search_term.lower(),
+ search_term.lower(),
+ search_term.lower(),
+ ]
+
+ where_clause = ""
+ if where_clauses:
+ where_clause = " AND " + " AND ".join(where_clauses)
+
+ sql = """
+ SELECT
+ room_id, name, topic, canonical_alias, joined_members,
+ avatar, history_visibility, joined_members, guest_access
+ FROM (
+ %(published_sql)s
+ ) published
+ INNER JOIN room_stats_state USING (room_id)
+ INNER JOIN room_stats_current USING (room_id)
+ WHERE
+ (
+ join_rules = 'public' OR history_visibility = 'world_readable'
+ )
+ AND joined_members > 0
+ %(where_clause)s
+ ORDER BY joined_members %(dir)s, room_id %(dir)s
+ """ % {
+ "published_sql": published_sql,
+ "where_clause": where_clause,
+ "dir": "DESC" if forwards else "ASC",
+ }
+
+ if limit is not None:
+ query_args.append(limit)
+
+ sql += """
+ LIMIT ?
+ """
+
+ def _get_largest_public_rooms_txn(txn):
+ txn.execute(sql, query_args)
+
+ results = self.db.cursor_to_dict(txn)
+
+ if not forwards:
+ results.reverse()
+
+ return results
+
+ ret_val = yield self.db.runInteraction(
+ "get_largest_public_rooms", _get_largest_public_rooms_txn
+ )
+ defer.returnValue(ret_val)
+
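Editor's note: the slightly odd bound in `get_largest_public_rooms` (`joined_members <= ? AND (joined_members < ? OR room_id < ?)`) is keyset pagination over (joined_members, room_id), with the common bound pulled out so the planner can use an index on `joined_members`. The predicate itself, in plain Python (illustrative only):

    def after_bound(row, bound, forwards=True):
        """row and bound are (joined_members, room_id); True if row comes after bound."""
        joined, room_id = row
        last_joined, last_room_id = bound
        if forwards:
            # Descending by joined_members, then descending by room_id.
            return joined <= last_joined and (joined < last_joined or room_id < last_room_id)
        return joined >= last_joined and (joined > last_joined or room_id > last_room_id)

    print(after_bound((10, "!b:hs"), (10, "!c:hs")))  # True: same size, smaller room_id
    print(after_bound((11, "!a:hs"), (10, "!c:hs")))  # False: a larger room sorts earlier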
+ @cached(max_entries=10000)
+ def is_room_blocked(self, room_id):
+ return self.db.simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="1",
+ allow_none=True,
+ desc="is_room_blocked",
+ )
+
+ @defer.inlineCallbacks
+ def is_room_published(self, room_id):
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id (str)
+ Returns:
+ bool: Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = yield self.get_room(room_id)
+ if not room_info:
+ defer.returnValue(False)
+
+ # Check the is_public value
+ defer.returnValue(room_info.get("is_public", False))
+
+ async def get_rooms_paginate(
+ self,
+ start: int,
+ limit: int,
+ order_by: RoomSortOrder,
+ reverse_order: bool,
+ search_term: Optional[str],
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ """Function to retrieve a paginated list of rooms as json.
+
+ Args:
+ start: offset in the list
+ limit: maximum amount of rooms to retrieve
+ order_by: the sort order of the returned list
+ reverse_order: whether to reverse the room list
+ search_term: a string to filter room names by
+ Returns:
+ A list of room dicts and an integer representing the total number of
+ rooms that exist given this query
+ """
+ # Filter room names by a string
+ where_statement = ""
+ if search_term:
+ where_statement = "WHERE state.name LIKE ?"
+
+ # Our postgres db driver converts ? -> %s in SQL strings as that's the
+ # placeholder for postgres.
+ # HOWEVER, if you put a % into your SQL then everything goes wibbly.
+ # To get around this, we're going to surround search_term with %'s
+ # before giving it to the database in python instead
+ search_term = "%" + search_term + "%"
+
+ # Set ordering
+ if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
+ order_by_column = "curr.joined_members"
+ order_by_asc = False
+ elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL:
+ # Sort alphabetically
+ order_by_column = "state.name"
+ order_by_asc = True
+ else:
+ raise StoreError(
+ 500, "Incorrect value for order_by provided: %s" % order_by
+ )
+
+ # Whether to return the list in reverse order
+ if reverse_order:
+ # Flip the boolean
+ order_by_asc = not order_by_asc
+
+        # Create one query for getting the limited number of rooms that the user asked
+        # for, and another query for getting the total number of rooms that could be
+        # returned, which lets us see whether there are more rooms to paginate through.
+ info_sql = """
+ SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members
+ FROM room_stats_state state
+ INNER JOIN room_stats_current curr USING (room_id)
+ %s
+ ORDER BY %s %s
+ LIMIT ?
+ OFFSET ?
+ """ % (
+ where_statement,
+ order_by_column,
+ "ASC" if order_by_asc else "DESC",
+ )
+
+ # Use a nested SELECT statement as SQL can't count(*) with an OFFSET
+ count_sql = """
+ SELECT count(*) FROM (
+ SELECT room_id FROM room_stats_state state
+ %s
+ ) AS get_room_ids
+ """ % (
+ where_statement,
+ )
+
+ def _get_rooms_paginate_txn(txn):
+ # Execute the data query
+ sql_values = (limit, start)
+ if search_term:
+ # Add the search term into the WHERE clause
+ sql_values = (search_term,) + sql_values
+ txn.execute(info_sql, sql_values)
+
+ # Refactor room query data into a structured dictionary
+ rooms = []
+ for room in txn:
+ rooms.append(
+ {
+ "room_id": room[0],
+ "name": room[1],
+ "canonical_alias": room[2],
+ "joined_members": room[3],
+ }
+ )
+
+ # Execute the count query
+
+ # Add the search term into the WHERE clause if present
+ sql_values = (search_term,) if search_term else ()
+ txn.execute(count_sql, sql_values)
+
+ room_count = txn.fetchone()
+ return rooms, room_count[0]
+
+ return await self.db.runInteraction(
+ "get_rooms_paginate", _get_rooms_paginate_txn,
+ )
+
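Editor's note: the comment about `%` in `get_rooms_paginate` is the standard rule that LIKE wildcards belong in the bound parameter, never in the SQL text, so they can't collide with the driver's placeholder handling. A minimal sqlite3 sketch of the same pattern (table and data invented for the example):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE room_stats_state (room_id TEXT, name TEXT)")
    conn.execute("INSERT INTO room_stats_state VALUES ('!a:hs', 'Synapse Admins')")

    search_term = "admin"
    pattern = "%" + search_term + "%"  # wildcards added in Python, not in the SQL
    rows = conn.execute(
        "SELECT room_id, name FROM room_stats_state WHERE name LIKE ?", (pattern,)
    ).fetchall()
    print(rows)  # [('!a:hs', 'Synapse Admins')]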
+ @cachedInlineCallbacks(max_entries=10000)
+ def get_ratelimit_for_user(self, user_id):
+ """Check if there are any overrides for ratelimiting for the given
+ user
+
+ Args:
+ user_id (str)
+
+ Returns:
+ RatelimitOverride if there is an override, else None. If the contents
+            of RatelimitOverride are None or 0 then ratelimiting has been
+ disabled for that user entirely.
+ """
+ row = yield self.db.simple_select_one(
+ table="ratelimit_override",
+ keyvalues={"user_id": user_id},
+ retcols=("messages_per_second", "burst_count"),
+ allow_none=True,
+ desc="get_ratelimit_for_user",
+ )
+
+ if row:
+ return RatelimitOverride(
+ messages_per_second=row["messages_per_second"],
+ burst_count=row["burst_count"],
+ )
+ else:
+ return None
+
+ @cachedInlineCallbacks()
+ def get_retention_policy_for_room(self, room_id):
+ """Get the retention policy for a given room.
+
+ If no retention policy has been found for this room, returns a policy defined
+ by the configured default policy (which has None as both the 'min_lifetime' and
+ the 'max_lifetime' if no default policy has been defined in the server's
+ configuration).
+
+ Args:
+ room_id (str): The ID of the room to get the retention policy of.
+
+ Returns:
+            dict[str, int|None]: The "min_lifetime" and "max_lifetime" for this room.
+ """
+        # If the room retention feature is disabled, return a policy with no minimum
+        # and no maximum, so that we don't filter out events that should be sent to
+        # the client.
+ if not self.config.retention_enabled:
+ defer.returnValue({"min_lifetime": None, "max_lifetime": None})
+
+ def get_retention_policy_for_room_txn(txn):
+ txn.execute(
+ """
+ SELECT min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ WHERE room_id = ?;
+ """,
+ (room_id,),
+ )
+
+ return self.db.cursor_to_dict(txn)
+
+ ret = yield self.db.runInteraction(
+ "get_retention_policy_for_room", get_retention_policy_for_room_txn,
+ )
+
+        # If we don't know this room ID, ret will be an empty list; in this case, return
+        # the default policy.
+ if not ret:
+ defer.returnValue(
+ {
+ "min_lifetime": self.config.retention_default_min_lifetime,
+ "max_lifetime": self.config.retention_default_max_lifetime,
+ }
+ )
+
+ row = ret[0]
+
+ # If one of the room's policy's attributes isn't defined, use the matching
+ # attribute from the default policy.
+ # The default values will be None if no default policy has been defined, or if one
+ # of the attributes is missing from the default policy.
+ if row["min_lifetime"] is None:
+ row["min_lifetime"] = self.config.retention_default_min_lifetime
+
+ if row["max_lifetime"] is None:
+ row["max_lifetime"] = self.config.retention_default_max_lifetime
+
+ defer.returnValue(row)
+
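Editor's note: the fallback logic at the end of `get_retention_policy_for_room` is just a per-attribute merge with the configured defaults. As a standalone sketch (the default values below are stand-ins for the server configuration):

    def merge_with_defaults(row, default_min=None, default_max=86400000):
        """Fill any missing lifetime from the default policy."""
        return {
            "min_lifetime": row["min_lifetime"] if row["min_lifetime"] is not None else default_min,
            "max_lifetime": row["max_lifetime"] if row["max_lifetime"] is not None else default_max,
        }

    print(merge_with_defaults({"min_lifetime": None, "max_lifetime": None}))
    # {'min_lifetime': None, 'max_lifetime': 86400000}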
+ def get_media_mxcs_in_room(self, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ room_id (str)
+
+ Returns:
+            A tuple of two lists of mxc URIs: the local media and the remote
+            media found in the room.
+ """
+
+ def _get_media_mxcs_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ # Convert the IDs to MXC URIs
+ for media_id in local_mxcs:
+ local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
+ for hostname, media_id in remote_mxcs:
+ remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
+
+ return self.db.runInteraction(
+ "get_media_ids_in_room", _get_media_mxcs_in_room_txn
+ )
+
+ def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+ """For a room loops through all events with media and quarantines
+ the associated media
+ """
+
+ logger.info("Quarantining media in room: %s", room_id)
+
+ def _quarantine_media_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ total_media_quarantined = 0
+
+ # Now update all the tables to set the quarantined_by flag
+
+ txn.executemany(
+ """
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """,
+ ((quarantined_by, media_id) for media_id in local_mxcs),
+ )
+
+ txn.executemany(
+ """
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ (
+ (quarantined_by, origin, media_id)
+ for origin, media_id in remote_mxcs
+ ),
+ )
+
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
+
+ return total_media_quarantined
+
+ return self.db.runInteraction(
+ "quarantine_media_in_room", _quarantine_media_in_room_txn
+ )
+
+ def _get_media_mxcs_in_room_txn(self, txn, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ txn (cursor)
+ room_id (str)
+
+ Returns:
+            A tuple of two lists: the local media IDs, and the remote media as
+            (hostname, media ID) tuples.
+ """
+ mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+ sql = """
+ SELECT stream_ordering, json FROM events
+ JOIN event_json USING (room_id, event_id)
+ WHERE room_id = ?
+ %(where_clause)s
+ AND contains_url = ? AND outlier = ?
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+ txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100))
+
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ while True:
+ next_token = None
+ for stream_ordering, content_json in txn:
+ next_token = stream_ordering
+ event_json = json.loads(content_json)
+ content = event_json["content"]
+ content_url = content.get("url")
+ thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+ for url in (content_url, thumbnail_url):
+ if not url:
+ continue
+ matches = mxc_re.match(url)
+ if matches:
+ hostname = matches.group(1)
+ media_id = matches.group(2)
+ if hostname == self.hs.hostname:
+ local_media_mxcs.append(media_id)
+ else:
+ remote_media_mxcs.append((hostname, media_id))
+
+ if next_token is None:
+ # We've gone through the whole room, so we're finished.
+ break
+
+ txn.execute(
+ sql % {"where_clause": "AND stream_ordering < ?"},
+ (room_id, next_token, True, False, 100),
+ )
+
+ return local_media_mxcs, remote_media_mxcs
+
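Editor's note: the regex in `_get_media_mxcs_in_room_txn` splits an `mxc://` URI into server name and media ID; whether the media is local then hinges on comparing that server name with our own. Standalone (hostnames invented for the example):

    import re

    MXC_RE = re.compile("^mxc://([^/]+)/([^/#?]+)")

    def classify_mxc(url, local_hostname):
        """Return ('local', media_id), ('remote', (host, media_id)), or None."""
        match = MXC_RE.match(url)
        if not match:
            return None
        hostname, media_id = match.group(1), match.group(2)
        if hostname == local_hostname:
            return ("local", media_id)
        return ("remote", (hostname, media_id))

    print(classify_mxc("mxc://example.org/abc123", "example.org"))  # ('local', 'abc123')
    print(classify_mxc("mxc://other.net/def456", "example.org"))    # ('remote', ('other.net', 'def456'))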
+ def quarantine_media_by_id(
+ self, server_name: str, media_id: str, quarantined_by: str,
+ ):
+ """quarantines a single local or remote media id
+
+ Args:
+ server_name: The name of the server that holds this media
+ media_id: The ID of the media to be quarantined
+ quarantined_by: The user ID that initiated the quarantine request
+ """
+ logger.info("Quarantining media: %s/%s", server_name, media_id)
+ is_local = server_name == self.config.server_name
+
+ def _quarantine_media_by_id_txn(txn):
+ local_mxcs = [media_id] if is_local else []
+ remote_mxcs = [(server_name, media_id)] if not is_local else []
+
+ return self._quarantine_media_txn(
+ txn, local_mxcs, remote_mxcs, quarantined_by
+ )
+
+ return self.db.runInteraction(
+ "quarantine_media_by_user", _quarantine_media_by_id_txn
+ )
+
+ def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+ """quarantines all local media associated with a single user
+
+ Args:
+ user_id: The ID of the user to quarantine media of
+ quarantined_by: The ID of the user who made the quarantine request
+ """
+
+ def _quarantine_media_by_user_txn(txn):
+ local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
+ return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
+
+ return self.db.runInteraction(
+ "quarantine_media_by_user", _quarantine_media_by_user_txn
+ )
+
+ def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+ """Retrieves local media IDs by a given user
+
+ Args:
+ txn (cursor)
+ user_id: The ID of the user to retrieve media IDs of
+
+ Returns:
+            A list of local media IDs uploaded by the given user. Quarantined media
+            is excluded unless filter_quarantined is False.
+ """
+ # Local media
+ sql = """
+ SELECT media_id
+ FROM local_media_repository
+ WHERE user_id = ?
+ """
+ if filter_quarantined:
+ sql += "AND quarantined_by IS NULL"
+ txn.execute(sql, (user_id,))
+
+ local_media_ids = [row[0] for row in txn]
+
+ # TODO: Figure out all remote media a user has referenced in a message
+
+ return local_media_ids
+
+ def _quarantine_media_txn(
+ self,
+ txn,
+ local_mxcs: List[str],
+ remote_mxcs: List[Tuple[str, str]],
+ quarantined_by: str,
+ ) -> int:
+ """Quarantine local and remote media items
+
+ Args:
+ txn (cursor)
+ local_mxcs: A list of local mxc URLs
+ remote_mxcs: A list of (remote server, media id) tuples representing
+ remote mxc URLs
+ quarantined_by: The ID of the user who initiated the quarantine request
+ Returns:
+ The total number of media items quarantined
+ """
+ total_media_quarantined = 0
+
+ # Update all the tables to set the quarantined_by flag
+ txn.executemany(
+ """
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """,
+ ((quarantined_by, media_id) for media_id in local_mxcs),
+ )
+
+ txn.executemany(
+ """
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
+ )
+
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
+
+ return total_media_quarantined
+
+
+class RoomBackgroundUpdateStore(SQLBaseStore):
+ REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
+ ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
+ self.config = hs.config
+
+ self.db.updates.register_background_update_handler(
+ "insert_room_retention", self._background_insert_retention,
+ )
+
+ self.db.updates.register_background_update_handler(
+ self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
+ self._remove_tombstoned_rooms_from_directory,
+ )
+
+ self.db.updates.register_background_update_handler(
+ self.ADD_ROOMS_ROOM_VERSION_COLUMN,
+ self._background_add_rooms_room_version_column,
+ )
+
+ @defer.inlineCallbacks
+ def _background_insert_retention(self, progress, batch_size):
+ """Retrieves a list of all rooms within a range and inserts an entry for each of
+ them into the room_retention table.
+ NULLs the property's columns if missing from the retention event in the room's
+ state (or NULLs all of them if there's no retention event in the room's state),
+ so that we fall back to the server's retention policy.
+ """
+
+ last_room = progress.get("room_id", "")
+
+ def _background_insert_retention_txn(txn):
+ txn.execute(
+ """
+ SELECT state.room_id, state.event_id, events.json
+ FROM current_state_events as state
+ LEFT JOIN event_json AS events ON (state.event_id = events.event_id)
+ WHERE state.room_id > ? AND state.type = '%s'
+ ORDER BY state.room_id ASC
+ LIMIT ?;
+ """
+ % EventTypes.Retention,
+ (last_room, batch_size),
+ )
+
+ rows = self.db.cursor_to_dict(txn)
+
+ if not rows:
+ return True
+
+ for row in rows:
+ if not row["json"]:
+ retention_policy = {}
+ else:
+ ev = json.loads(row["json"])
+                    retention_policy = ev["content"]
+
+ self.db.simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": row["room_id"],
+ "event_id": row["event_id"],
+ "min_lifetime": retention_policy.get("min_lifetime"),
+ "max_lifetime": retention_policy.get("max_lifetime"),
+ },
+ )
+
+ logger.info("Inserted %d rows into room_retention", len(rows))
+
+ self.db.updates._background_update_progress_txn(
+ txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+ )
+
+ if batch_size > len(rows):
+ return True
+ else:
+ return False
+
+ end = yield self.db.runInteraction(
+ "insert_room_retention", _background_insert_retention_txn,
+ )
+
+ if end:
+ yield self.db.updates._end_background_update("insert_room_retention")
+
+ defer.returnValue(batch_size)
+
+ async def _background_add_rooms_room_version_column(
+ self, progress: dict, batch_size: int
+ ):
+ """Background update to go and add room version inforamtion to `rooms`
+ table from `current_state_events` table.
+ """
+
+ last_room_id = progress.get("room_id", "")
+
+ def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
+ sql = """
+ SELECT room_id, json FROM current_state_events
+ INNER JOIN event_json USING (room_id, event_id)
+ WHERE room_id > ? AND type = 'm.room.create' AND state_key = ''
+ ORDER BY room_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_room_id, batch_size))
+
+ updates = []
+ for room_id, event_json in txn:
+ event_dict = json.loads(event_json)
+ room_version_id = event_dict.get("content", {}).get(
+ "room_version", RoomVersions.V1.identifier
+ )
+
+ creator = event_dict.get("content").get("creator")
+
+ updates.append((room_id, creator, room_version_id))
+
+ if not updates:
+ return True
+
+ new_last_room_id = ""
+ for room_id, creator, room_version_id in updates:
+ # We upsert here just in case we don't already have a row,
+ # mainly for paranoia as much badness would happen if we don't
+ # insert the row and then try and get the room version for the
+ # room.
+ self.db.simple_upsert_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ values={"room_version": room_version_id},
+ insertion_values={"is_public": False, "creator": creator},
+ )
+ new_last_room_id = room_id
+
+ self.db.updates._background_update_progress_txn(
+ txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id}
+ )
+
+ return False
+
+ end = await self.db.runInteraction(
+ "_background_add_rooms_room_version_column",
+ _background_add_rooms_room_version_column_txn,
+ )
+
+ if end:
+ await self.db.updates._end_background_update(
+ self.ADD_ROOMS_ROOM_VERSION_COLUMN
+ )
+
+ return batch_size
+
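Editor's note: both background updates above share the same batching shape: read a page of rows ordered by room_id starting after the saved progress marker, process it, store the last room_id as the new marker, and finish when a page comes back empty (or short). Stripped of the database plumbing, the loop looks like this (illustrative only):

    def run_batched_update(room_ids, batch_size, process_batch):
        """room_ids must be sorted; process_batch receives one page at a time."""
        last_room_id = ""  # progress marker, normally persisted between runs
        while True:
            batch = [r for r in room_ids if r > last_room_id][:batch_size]
            if not batch:
                break  # nothing left: the update is complete
            process_batch(batch)
            last_room_id = batch[-1]  # persist this as the new progress marker

    run_batched_update(
        ["!a:hs", "!b:hs", "!c:hs"], 2, lambda batch: print("processing", batch)
    )
    # processing ['!a:hs', '!b:hs']
    # processing ['!c:hs']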
+ async def _remove_tombstoned_rooms_from_directory(
+ self, progress, batch_size
+ ) -> int:
+ """Removes any rooms with tombstone events from the room directory
+
+ Nowadays this is handled by the room upgrade handler, but we may have some
+ that got left behind
+ """
+
+ last_room = progress.get("room_id", "")
+
+ def _get_rooms(txn):
+ txn.execute(
+ """
+ SELECT room_id
+ FROM rooms r
+ INNER JOIN current_state_events cse USING (room_id)
+ WHERE room_id > ? AND r.is_public
+ AND cse.type = '%s' AND cse.state_key = ''
+ ORDER BY room_id ASC
+ LIMIT ?;
+ """
+ % EventTypes.Tombstone,
+ (last_room, batch_size),
+ )
+
+ return [row[0] for row in txn]
+
+ rooms = await self.db.runInteraction(
+ "get_tombstoned_directory_rooms", _get_rooms
+ )
+
+ if not rooms:
+ await self.db.updates._end_background_update(
+ self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
+ )
+ return 0
+
+ for room_id in rooms:
+ logger.info("Removing tombstoned room %s from the directory", room_id)
+ await self.set_room_is_public(room_id, False)
+
+ await self.db.updates._background_update_progress(
+ self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
+ )
+
+ return len(rooms)
+
+ @abstractmethod
+ def set_room_is_public(self, room_id, is_public):
+ # this will need to be implemented if a background update is performed with
+ # existing (tombstoned, public) rooms in the database.
+ #
+ # It's overridden by RoomStore for the synapse master.
+ raise NotImplementedError()
+
+
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomStore, self).__init__(database, db_conn, hs)
+
+ self.config = hs.config
+
+ async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion):
+ """Ensure that the room is stored in the table
+
+ Called when we join a room over federation, and overwrites any room version
+ currently in the table.
+ """
+ await self.db.simple_upsert(
+ desc="upsert_room_on_join",
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ values={"room_version": room_version.identifier},
+ insertion_values={"is_public": False, "creator": ""},
+ # rooms has a unique constraint on room_id, so no need to lock when doing an
+ # emulated upsert.
+ lock=False,
+ )
+
+ @defer.inlineCallbacks
+ def store_room(
+ self,
+ room_id: str,
+ room_creator_user_id: str,
+ is_public: bool,
+ room_version: RoomVersion,
+ ):
+ """Stores a room.
+
+ Args:
+ room_id: The desired room ID, can be None.
+ room_creator_user_id: The user ID of the room creator.
+ is_public: True to indicate that this room should appear in
+ public room lists.
+ room_version: The version of the room
+ Raises:
+ StoreError if the room could not be stored.
+ """
+ try:
+
+ def store_room_txn(txn, next_id):
+ self.db.simple_insert_txn(
+ txn,
+ "rooms",
+ {
+ "room_id": room_id,
+ "creator": room_creator_user_id,
+ "is_public": is_public,
+ "room_version": room_version.identifier,
+ },
+ )
+ if is_public:
+ self.db.simple_insert_txn(
+ txn,
+ table="public_room_list_stream",
+ values={
+ "stream_id": next_id,
+ "room_id": room_id,
+ "visibility": is_public,
+ },
+ )
+
+ with self._public_room_id_gen.get_next() as next_id:
+ yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
+ except Exception as e:
+ logger.error("store_room with room_id=%s failed: %s", room_id, e)
+ raise StoreError(500, "Problem creating room.")
+
+ async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion):
+ """
+ When we receive an invite over federation, store the version of the room if we
+ don't already know the room version.
+ """
+ await self.db.simple_upsert(
+ desc="maybe_store_room_on_invite",
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ values={},
+ insertion_values={
+ "room_version": room_version.identifier,
+ "is_public": False,
+ "creator": "",
+ },
+ # rooms has a unique constraint on room_id, so no need to lock when doing an
+ # emulated upsert.
+ lock=False,
+ )
+
+ @defer.inlineCallbacks
+ def set_room_is_public(self, room_id, is_public):
+ def set_room_is_public_txn(txn, next_id):
+ self.db.simple_update_one_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"is_public": is_public},
+ )
+
+ entries = self.db.simple_select_list_txn(
+ txn,
+ table="public_room_list_stream",
+ keyvalues={
+ "room_id": room_id,
+ "appservice_id": None,
+ "network_id": None,
+ },
+ retcols=("stream_id", "visibility"),
+ )
+
+ entries.sort(key=lambda r: r["stream_id"])
+
+ add_to_stream = True
+ if entries:
+ add_to_stream = bool(entries[-1]["visibility"]) != is_public
+
+ if add_to_stream:
+ self.db.simple_insert_txn(
+ txn,
+ table="public_room_list_stream",
+ values={
+ "stream_id": next_id,
+ "room_id": room_id,
+ "visibility": is_public,
+ "appservice_id": None,
+ "network_id": None,
+ },
+ )
+
+ with self._public_room_id_gen.get_next() as next_id:
+ yield self.db.runInteraction(
+ "set_room_is_public", set_room_is_public_txn, next_id
+ )
+ self.hs.get_notifier().on_new_replication_data()
+
+ @defer.inlineCallbacks
+ def set_room_is_public_appservice(
+ self, room_id, appservice_id, network_id, is_public
+ ):
+ """Edit the appservice/network specific public room list.
+
+ Each appservice can have a number of published room lists associated
+ with them, keyed off of an appservice defined `network_id`, which
+ basically represents a single instance of a bridge to a third party
+ network.
+
+ Args:
+ room_id (str)
+ appservice_id (str)
+ network_id (str)
+ is_public (bool): Whether to publish or unpublish the room from the
+ list.
+ """
+
+ def set_room_is_public_appservice_txn(txn, next_id):
+ if is_public:
+ try:
+ self.db.simple_insert_txn(
+ txn,
+ table="appservice_room_list",
+ values={
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ "room_id": room_id,
+ },
+ )
+ except self.database_engine.module.IntegrityError:
+ # We've already inserted, nothing to do.
+ return
+ else:
+ self.db.simple_delete_txn(
+ txn,
+ table="appservice_room_list",
+ keyvalues={
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ "room_id": room_id,
+ },
+ )
+
+ entries = self.db.simple_select_list_txn(
+ txn,
+ table="public_room_list_stream",
+ keyvalues={
+ "room_id": room_id,
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ },
+ retcols=("stream_id", "visibility"),
+ )
+
+ entries.sort(key=lambda r: r["stream_id"])
+
+ add_to_stream = True
+ if entries:
+ add_to_stream = bool(entries[-1]["visibility"]) != is_public
+
+ if add_to_stream:
+ self.db.simple_insert_txn(
+ txn,
+ table="public_room_list_stream",
+ values={
+ "stream_id": next_id,
+ "room_id": room_id,
+ "visibility": is_public,
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ },
+ )
+
+ with self._public_room_id_gen.get_next() as next_id:
+ yield self.db.runInteraction(
+ "set_room_is_public_appservice",
+ set_room_is_public_appservice_txn,
+ next_id,
+ )
+ self.hs.get_notifier().on_new_replication_data()
+
+ def get_room_count(self):
+ """Retrieve a list of all rooms
+ """
+
+ def f(txn):
+ sql = "SELECT count(*) FROM rooms"
+ txn.execute(sql)
+ row = txn.fetchone()
+ return row[0] or 0
+
+ return self.db.runInteraction("get_rooms", f)
+
+ def _store_room_topic_txn(self, txn, event):
+ if hasattr(event, "content") and "topic" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.topic", event.content["topic"]
+ )
+
+ def _store_room_name_txn(self, txn, event):
+ if hasattr(event, "content") and "name" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.name", event.content["name"]
+ )
+
+ def _store_room_message_txn(self, txn, event):
+ if hasattr(event, "content") and "body" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.body", event.content["body"]
+ )
+
+ def _store_retention_policy_for_room_txn(self, txn, event):
+ if hasattr(event, "content") and (
+ "min_lifetime" in event.content or "max_lifetime" in event.content
+ ):
+ if (
+ "min_lifetime" in event.content
+ and not isinstance(event.content.get("min_lifetime"), integer_types)
+ ) or (
+ "max_lifetime" in event.content
+ and not isinstance(event.content.get("max_lifetime"), integer_types)
+ ):
+                # Ignore the event if one of the values isn't an integer.
+ return
+
+ self.db.simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ "min_lifetime": event.content.get("min_lifetime"),
+ "max_lifetime": event.content.get("max_lifetime"),
+ },
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_retention_policy_for_room, (event.room_id,)
+ )
+
+ def add_event_report(
+ self, room_id, event_id, user_id, reason, content, received_ts
+ ):
+ next_id = self._event_reports_id_gen.get_next()
+ return self.db.simple_insert(
+ table="event_reports",
+ values={
+ "id": next_id,
+ "received_ts": received_ts,
+ "room_id": room_id,
+ "event_id": event_id,
+ "user_id": user_id,
+ "reason": reason,
+ "content": json.dumps(content),
+ },
+ desc="add_event_report",
+ )
+
+ def get_current_public_room_stream_id(self):
+ return self._public_room_id_gen.get_current_token()
+
+ def get_all_new_public_rooms(self, prev_id, current_id, limit):
+ def get_all_new_public_rooms(txn):
+ sql = """
+ SELECT stream_id, room_id, visibility, appservice_id, network_id
+ FROM public_room_list_stream
+ WHERE stream_id > ? AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (prev_id, current_id, limit))
+ return txn.fetchall()
+
+ if prev_id == current_id:
+ return defer.succeed([])
+
+ return self.db.runInteraction(
+ "get_all_new_public_rooms", get_all_new_public_rooms
+ )
+
+ @defer.inlineCallbacks
+ def block_room(self, room_id, user_id):
+ """Marks the room as blocked. Can be called multiple times.
+
+ Args:
+ room_id (str): Room to block
+ user_id (str): Who blocked it
+
+ Returns:
+ Deferred
+ """
+ yield self.db.simple_upsert(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ values={},
+ insertion_values={"user_id": user_id},
+ desc="block_room",
+ )
+ yield self.db.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked,
+ (room_id,),
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms_for_retention_period_in_range(
+ self, min_ms, max_ms, include_null=False
+ ):
+ """Retrieves all of the rooms within the given retention range.
+
+ Optionally includes the rooms which don't have a retention policy.
+
+ Args:
+            min_ms (int|None): Duration in milliseconds that defines the lower limit of
+                the range to handle (exclusive). If None, no lower limit is set.
+            max_ms (int|None): Duration in milliseconds that defines the upper limit of
+                the range to handle (inclusive). If None, no upper limit is set.
+            include_null (bool): Whether to include rooms whose retention policy is NULL
+                in the returned set.
+
+ Returns:
+ dict[str, dict]: The rooms within this range, along with their retention
+ policy. The key is "room_id", and maps to a dict describing the retention
+ policy associated with this room ID. The keys for this nested dict are
+ "min_lifetime" (int|None), and "max_lifetime" (int|None).
+ """
+
+ def get_rooms_for_retention_period_in_range_txn(txn):
+ range_conditions = []
+ args = []
+
+ if min_ms is not None:
+ range_conditions.append("max_lifetime > ?")
+ args.append(min_ms)
+
+ if max_ms is not None:
+ range_conditions.append("max_lifetime <= ?")
+ args.append(max_ms)
+
+ # Do a first query which will retrieve the rooms that have a retention policy
+ # in their current state.
+ sql = """
+ SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ """
+
+ if len(range_conditions):
+ sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+ if include_null:
+ sql += " OR max_lifetime IS NULL"
+
+ txn.execute(sql, args)
+
+ rows = self.db.cursor_to_dict(txn)
+ rooms_dict = {}
+
+ for row in rows:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": row["min_lifetime"],
+ "max_lifetime": row["max_lifetime"],
+ }
+
+ if include_null:
+ # If required, do a second query that retrieves all of the rooms we know
+ # of so we can handle rooms with no retention policy.
+ sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+ txn.execute(sql)
+
+ rows = self.db.cursor_to_dict(txn)
+
+ # If a room isn't already in the dict (i.e. it doesn't have a retention
+ # policy in its state), add it with a null policy.
+ for row in rows:
+ if row["room_id"] not in rooms_dict:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": None,
+ "max_lifetime": None,
+ }
+
+ return rooms_dict
+
+ rooms = yield self.db.runInteraction(
+ "get_rooms_for_retention_period_in_range",
+ get_rooms_for_retention_period_in_range_txn,
+ )
+
+ defer.returnValue(rooms)
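Editor's note: the range handled by `get_rooms_for_retention_period_in_range` is exclusive at the lower end, inclusive at the upper end, with an optional pass for rooms that have no policy at all. The same selection over an in-memory map, as a sketch:

    def rooms_in_retention_range(policies, min_ms, max_ms, include_null=False):
        """policies maps room_id -> max_lifetime (None when no policy is set)."""
        selected = {}
        for room_id, max_lifetime in policies.items():
            if max_lifetime is None:
                if include_null:
                    selected[room_id] = None
                continue
            if min_ms is not None and max_lifetime <= min_ms:
                continue  # lower bound is exclusive
            if max_ms is not None and max_lifetime > max_ms:
                continue  # upper bound is inclusive
            selected[room_id] = max_lifetime
        return selected

    print(rooms_in_retention_range(
        {"!a:hs": 1000, "!b:hs": 5000, "!c:hs": None}, min_ms=1000, max_ms=5000, include_null=True
    ))
    # {'!b:hs': 5000, '!c:hs': None}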
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
new file mode 100644
index 0000000000..d5bd0cb5cf
--- /dev/null
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -0,0 +1,1265 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Iterable, List, Set
+
+from six import iteritems, itervalues
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import (
+ LoggingTransaction,
+ SQLBaseStore,
+ make_in_list_sql_clause,
+)
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
+from synapse.storage.engines import Sqlite3Engine
+from synapse.storage.roommember import (
+ GetRoomsForUserWithStreamOrdering,
+ MemberSummary,
+ ProfileInfo,
+ RoomsForUser,
+)
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches import intern_string
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.metrics import Measure
+from synapse.util.stringutils import to_ascii
+
+logger = logging.getLogger(__name__)
+
+
+_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
+_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
+
+
+class RoomMemberWorkerStore(EventsWorkerStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
+
+ # Is the current_state_events.membership up to date? Or is the
+ # background update still running?
+ self._current_state_events_membership_up_to_date = False
+
+ txn = LoggingTransaction(
+ db_conn.cursor(),
+ name="_check_safe_current_state_events_membership_updated",
+ database_engine=self.database_engine,
+ )
+ self._check_safe_current_state_events_membership_updated_txn(txn)
+ txn.close()
+
+ if self.hs.config.metrics_flags.known_servers:
+ self._known_servers_count = 1
+ self.hs.get_clock().looping_call(
+ run_as_background_process,
+ 60 * 1000,
+ "_count_known_servers",
+ self._count_known_servers,
+ )
+ self.hs.get_clock().call_later(
+ 1000,
+ run_as_background_process,
+ "_count_known_servers",
+ self._count_known_servers,
+ )
+ LaterGauge(
+ "synapse_federation_known_servers",
+ "",
+ [],
+ lambda: self._known_servers_count,
+ )
+
+ @defer.inlineCallbacks
+ def _count_known_servers(self):
+ """
+ Count the servers that this server knows about.
+
+ The statistic is stored on the class for the
+ `synapse_federation_known_servers` LaterGauge to collect.
+ """
+
+ def _transact(txn):
+ if isinstance(self.database_engine, Sqlite3Engine):
+ query = """
+ SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
+ FROM (
+ SELECT rm.user_id as user_id, instr(rm.user_id, ':')
+ AS pos FROM room_memberships as rm
+ INNER JOIN current_state_events as c ON rm.event_id = c.event_id
+ WHERE c.type = 'm.room.member'
+ ) as out
+ """
+ else:
+ query = """
+ SELECT COUNT(DISTINCT split_part(state_key, ':', 2))
+ FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join';
+ """
+ txn.execute(query)
+ return list(txn)[0][0]
+
+ count = yield self.db.runInteraction("get_known_servers", _transact)
+
+ # We always know about ourselves, even if we have nothing in
+ # room_memberships (for example, the server is new).
+ self._known_servers_count = max([count, 1])
+ return self._known_servers_count
+
+ def _check_safe_current_state_events_membership_updated_txn(self, txn):
+ """Checks if it is safe to assume the new current_state_events
+ membership column is up to date
+ """
+
+ pending_update = self.db.simple_select_one_txn(
+ txn,
+ table="background_updates",
+ keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
+ retcols=["update_name"],
+ allow_none=True,
+ )
+
+ self._current_state_events_membership_up_to_date = not pending_update
+
+ # If the update is still running, reschedule to run.
+ if pending_update:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "_check_safe_current_state_events_membership_updated",
+ self.db.runInteraction,
+ "_check_safe_current_state_events_membership_updated",
+ self._check_safe_current_state_events_membership_updated_txn,
+ )
+
+ @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
+ def get_hosts_in_room(self, room_id, cache_context):
+ """Returns the set of all hosts currently in the room
+ """
+ user_ids = yield self.get_users_in_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+ hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
+ return hosts
+
+ @cached(max_entries=100000, iterable=True)
+ def get_users_in_room(self, room_id):
+ return self.db.runInteraction(
+ "get_users_in_room", self.get_users_in_room_txn, room_id
+ )
+
+ def get_users_in_room_txn(self, txn, room_id):
+ # If we can assume current_state_events.membership is up to date
+ # then we can avoid a join, which is a Very Good Thing given how
+ # frequently this function gets called.
+ if self._current_state_events_membership_up_to_date:
+ sql = """
+ SELECT state_key FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+ """
+ else:
+ sql = """
+ SELECT state_key FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+ """
+
+ txn.execute(sql, (room_id, Membership.JOIN))
+ return [to_ascii(r[0]) for r in txn]
+
+ @cached(max_entries=100000)
+ def get_room_summary(self, room_id):
+ """ Get the details of a room roughly suitable for use by the room
+ summary extension to /sync. Useful when lazy loading room members.
+ Args:
+ room_id (str): The room ID to query
+ Returns:
+ Deferred[dict[str, MemberSummary]:
+ dict of membership states, pointing to a MemberSummary named tuple.
+ """
+
+ def _get_room_summary_txn(txn):
+ # first get counts.
+ # We do this all in one transaction to keep the cache small.
+ # FIXME: get rid of this when we have room_stats
+
+ # If we can assume current_state_events.membership is up to date
+ # then we can avoid a join, which is a Very Good Thing given how
+ # frequently this function gets called.
+ if self._current_state_events_membership_up_to_date:
+ # Note, rejected events will have a null membership field, so
+ # we manually filter them out.
+ sql = """
+ SELECT count(*), membership FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ AND membership IS NOT NULL
+ GROUP BY membership
+ """
+ else:
+ sql = """
+ SELECT count(*), m.membership FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
+
+ txn.execute(sql, (room_id,))
+ res = {}
+ for count, membership in txn:
+ summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
+
+ # we order by membership and then fairly arbitrarily by event_id so
+ # heroes are consistent
+ if self._current_state_events_membership_up_to_date:
+ # Note, rejected events will have a null membership field, so
+ # we manually filter them out.
+ sql = """
+ SELECT state_key, membership, event_id
+ FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ AND membership IS NOT NULL
+ ORDER BY
+ CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+ event_id ASC
+ LIMIT ?
+ """
+ else:
+ sql = """
+ SELECT c.state_key, m.membership, c.event_id
+ FROM room_memberships as m
+ INNER JOIN current_state_events as c USING (room_id, event_id)
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ ORDER BY
+ CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+ c.event_id ASC
+ LIMIT ?
+ """
+
+ # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
+ txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
+ for user_id, membership, event_id in txn:
+ summary = res[to_ascii(membership)]
+ # we will always have a summary for this membership type at this
+ # point given the summary currently contains the counts.
+ members = summary.members
+ members.append((to_ascii(user_id), to_ascii(event_id)))
+
+ return res
+
+ return self.db.runInteraction("get_room_summary", _get_room_summary_txn)
+
+ def _get_user_counts_in_room_txn(self, txn, room_id):
+ """
+ Get the user count in a room by membership.
+
+ Args:
+ room_id (str)
+ membership (Membership)
+
+ Returns:
+ Deferred[int]
+ """
+ sql = """
+ SELECT m.membership, count(*) FROM room_memberships as m
+ INNER JOIN current_state_events as c USING(event_id)
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
+
+ txn.execute(sql, (room_id,))
+ return {row[0]: row[1] for row in txn}
+
+ @cached()
+ def get_invited_rooms_for_local_user(self, user_id):
+ """ Get all the rooms the *local* user is invited to
+
+ Args:
+ user_id (str): The user ID.
+ Returns:
+ A deferred list of RoomsForUser.
+ """
+
+ return self.get_rooms_for_local_user_where_membership_is(
+ user_id, [Membership.INVITE]
+ )
+
+ @defer.inlineCallbacks
+ def get_invite_for_local_user_in_room(self, user_id, room_id):
+ """Gets the invite for the given *local* user and room
+
+ Args:
+ user_id (str)
+ room_id (str)
+
+ Returns:
+ Deferred: Resolves to either a RoomsForUser or None if no invite was
+ found.
+ """
+ invites = yield self.get_invited_rooms_for_local_user(user_id)
+ for invite in invites:
+ if invite.room_id == room_id:
+ return invite
+ return None
+
+ @defer.inlineCallbacks
+ def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
+ """ Get all the rooms for this *local* user where the membership for this user
+ matches one in the membership list.
+
+ Filters out forgotten rooms.
+
+ Args:
+ user_id (str): The user ID.
+ membership_list (list): A list of synapse.api.constants.Membership
+ values which the user must be in.
+
+ Returns:
+ Deferred[list[RoomsForUser]]
+ """
+ if not membership_list:
+ return defer.succeed(None)
+
+ rooms = yield self.db.runInteraction(
+ "get_rooms_for_local_user_where_membership_is",
+ self._get_rooms_for_local_user_where_membership_is_txn,
+ user_id,
+ membership_list,
+ )
+
+ # Now we filter out forgotten rooms
+ forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
+ return [room for room in rooms if room.room_id not in forgotten_rooms]
+
+ def _get_rooms_for_local_user_where_membership_is_txn(
+ self, txn, user_id, membership_list
+ ):
+ # Paranoia check.
+ if not self.hs.is_mine_id(user_id):
+ raise Exception(
+ "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
+ % (user_id,),
+ )
+
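+ # Build an engine-appropriate "c.membership IN (...)" style clause and
+ # the corresponding bound arguments for the requested memberships.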
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "c.membership", membership_list
+ )
+
+ sql = """
+ SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+ FROM local_current_membership AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ user_id = ?
+ AND %s
+ """ % (
+ clause,
+ )
+
+ txn.execute(sql, (user_id, *args))
+ results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
+
+ return results
+
+ @cached(max_entries=500000, iterable=True)
+ def get_rooms_for_user_with_stream_ordering(self, user_id):
+ """Returns a set of room_ids the user is currently joined to.
+
+ For a remote user, this only returns rooms that this server is
+ currently participating in.
+
+ Args:
+ user_id (str)
+
+ Returns:
+ Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
+ the rooms the user is in currently, along with the stream ordering
+ of the most recent join for that user and room.
+ """
+ return self.db.runInteraction(
+ "get_rooms_for_user_with_stream_ordering",
+ self._get_rooms_for_user_with_stream_ordering_txn,
+ user_id,
+ )
+
+ def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
+ # We use `current_state_events` here and not `local_current_membership`
+ # as a) this gets called with remote users and b) this only gets called
+ # for rooms the server is participating in.
+ if self._current_state_events_membership_up_to_date:
+ sql = """
+ SELECT room_id, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND state_key = ?
+ AND c.membership = ?
+ """
+ else:
+ sql = """
+ SELECT room_id, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN room_memberships AS m USING (room_id, event_id)
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND state_key = ?
+ AND m.membership = ?
+ """
+
+ txn.execute(sql, (user_id, Membership.JOIN))
+ results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+
+ return results
+
+ async def get_users_server_still_shares_room_with(
+ self, user_ids: Collection[str]
+ ) -> Set[str]:
+ """Given a list of users return the set that the server still share a
+ room with.
+ """
+
+ if not user_ids:
+ return set()
+
+ def _get_users_server_still_shares_room_with_txn(txn):
+ sql = """
+ SELECT state_key FROM current_state_events
+ WHERE
+ type = 'm.room.member'
+ AND membership = 'join'
+ AND %s
+ GROUP BY state_key
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "state_key", user_ids
+ )
+
+ txn.execute(sql % (clause,), args)
+
+ return {row[0] for row in txn}
+
+ return await self.db.runInteraction(
+ "get_users_server_still_shares_room_with",
+ _get_users_server_still_shares_room_with_txn,
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms_for_user(self, user_id, on_invalidate=None):
+ """Returns a set of room_ids the user is currently joined to.
+
+ For a remote user, this only returns rooms that this server is
+ currently participating in.
+ """
+ rooms = yield self.get_rooms_for_user_with_stream_ordering(
+ user_id, on_invalidate=on_invalidate
+ )
+ return frozenset(r.room_id for r in rooms)
+
+ @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
+ def get_users_who_share_room_with_user(self, user_id, cache_context):
+ """Returns the set of users who share a room with `user_id`
+ """
+ room_ids = yield self.get_rooms_for_user(
+ user_id, on_invalidate=cache_context.invalidate
+ )
+
+ user_who_share_room = set()
+ for room_id in room_ids:
+ user_ids = yield self.get_users_in_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+ user_who_share_room.update(user_ids)
+
+ return user_who_share_room
+
+ @defer.inlineCallbacks
+ def get_joined_users_from_context(self, event, context):
+ state_group = context.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ current_state_ids = yield context.get_current_state_ids()
+ result = yield self._get_joined_users_from_context(
+ event.room_id, state_group, current_state_ids, event=event, context=context
+ )
+ return result
+
+ @defer.inlineCallbacks
+ def get_joined_users_from_state(self, room_id, state_entry):
+ state_group = state_entry.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ with Measure(self._clock, "get_joined_users_from_state"):
+ return (
+ yield self._get_joined_users_from_context(
+ room_id, state_group, state_entry.state, context=state_entry
+ )
+ )
+
+ @cachedInlineCallbacks(
+ num_args=2, cache_context=True, iterable=True, max_entries=100000
+ )
+ def _get_joined_users_from_context(
+ self,
+ room_id,
+ state_group,
+ current_state_ids,
+ cache_context,
+ event=None,
+ context=None,
+ ):
+ # We don't use `state_group`, it's there so that we can cache based
+ # on it. However, it's important that it's never None, since two current_states
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ users_in_room = {}
+ member_event_ids = [
+ e_id
+ for key, e_id in iteritems(current_state_ids)
+ if key[0] == EventTypes.Member
+ ]
+
+ if context is not None:
+ # If we have a context with a delta from a previous state group,
+ # check if we also have the result from the previous group in cache.
+ # If we do then we can reuse that result and simply update it with
+ # any membership changes in `delta_ids`
+ if context.prev_group and context.delta_ids:
+ prev_res = self._get_joined_users_from_context.cache.get(
+ (room_id, context.prev_group), None
+ )
+ if prev_res and isinstance(prev_res, dict):
+ users_in_room = dict(prev_res)
+ member_event_ids = [
+ e_id
+ for key, e_id in iteritems(context.delta_ids)
+ if key[0] == EventTypes.Member
+ ]
+ for etype, state_key in context.delta_ids:
+ users_in_room.pop(state_key, None)
+
+ # We check if we have any of the member event ids in the event cache
+ # before we ask the DB
+
+ # We don't update the event cache hit ratio as it completely throws off
+ # the hit ratio counts. After all, we don't populate the cache if we
+ # miss it here
+ event_map = self._get_events_from_cache(
+ member_event_ids, allow_rejected=False, update_metrics=False
+ )
+
+ missing_member_event_ids = []
+ for event_id in member_event_ids:
+ ev_entry = event_map.get(event_id)
+ if ev_entry:
+ if ev_entry.event.membership == Membership.JOIN:
+ users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
+ display_name=to_ascii(
+ ev_entry.event.content.get("displayname", None)
+ ),
+ avatar_url=to_ascii(
+ ev_entry.event.content.get("avatar_url", None)
+ ),
+ )
+ else:
+ missing_member_event_ids.append(event_id)
+
+ if missing_member_event_ids:
+ event_to_memberships = yield self._get_joined_profiles_from_event_ids(
+ missing_member_event_ids
+ )
+ users_in_room.update((row for row in event_to_memberships.values() if row))
+
+ if event is not None and event.type == EventTypes.Member:
+ if event.membership == Membership.JOIN:
+ if event.event_id in member_event_ids:
+ users_in_room[to_ascii(event.state_key)] = ProfileInfo(
+ display_name=to_ascii(event.content.get("displayname", None)),
+ avatar_url=to_ascii(event.content.get("avatar_url", None)),
+ )
+
+ return users_in_room
+
+ @cached(max_entries=10000)
+ def _get_joined_profile_from_event_id(self, event_id):
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_joined_profile_from_event_id",
+ list_name="event_ids",
+ inlineCallbacks=True,
+ )
+ def _get_joined_profiles_from_event_ids(self, event_ids):
+ """For given set of member event_ids check if they point to a join
+ event and if so return the associated user and profile info.
+
+ Args:
+ event_ids (Iterable[str]): The member event IDs to lookup
+
+ Returns:
+ Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+ to `user_id` and ProfileInfo (or None if not join event).
+ """
+
+ rows = yield self.db.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=event_ids,
+ retcols=("user_id", "display_name", "avatar_url", "event_id"),
+ keyvalues={"membership": Membership.JOIN},
+ batch_size=500,
+ desc="_get_membership_from_event_ids",
+ )
+
+ return {
+ row["event_id"]: (
+ row["user_id"],
+ ProfileInfo(
+ avatar_url=row["avatar_url"], display_name=row["display_name"]
+ ),
+ )
+ for row in rows
+ }
+
+ @cachedInlineCallbacks(max_entries=10000)
+ def is_host_joined(self, room_id, host):
+ if "%" in host or "_" in host:
+ raise Exception("Invalid host name")
+
+ sql = """
+ SELECT state_key FROM current_state_events AS c
+ INNER JOIN room_memberships AS m USING (event_id)
+ WHERE m.membership = 'join'
+ AND type = 'm.room.member'
+ AND c.room_id = ?
+ AND state_key LIKE ?
+ LIMIT 1
+ """
+
+ # We do need to be careful to ensure that host doesn't have any wild cards
+ # in it, but we checked above for known ones and we'll check below that
+ # the returned user actually has the correct domain.
+ like_clause = "%:" + host
+
+ rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause)
+
+ if not rows:
+ return False
+
+ user_id = rows[0][0]
+ if get_domain_from_id(user_id) != host:
+ # This can only happen if the host name has something funky in it
+ raise Exception("Invalid host name")
+
+ return True
+
+ @cachedInlineCallbacks()
+ def was_host_joined(self, room_id, host):
+ """Check whether the server is or ever was in the room.
+
+ Args:
+ room_id (str)
+ host (str)
+
+ Returns:
+ Deferred: Resolves to True if the host is/was in the room, otherwise
+ False.
+ """
+ if "%" in host or "_" in host:
+ raise Exception("Invalid host name")
+
+ sql = """
+ SELECT user_id FROM room_memberships
+ WHERE room_id = ?
+ AND user_id LIKE ?
+ AND membership = 'join'
+ LIMIT 1
+ """
+
+ # We do need to be careful to ensure that host doesn't have any wild cards
+ # in it, but we checked above for known ones and we'll check below that
+ # the returned user actually has the correct domain.
+ like_clause = "%:" + host
+
+ rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause)
+
+ if not rows:
+ return False
+
+ user_id = rows[0][0]
+ if get_domain_from_id(user_id) != host:
+ # This can only happen if the host name has something funky in it
+ raise Exception("Invalid host name")
+
+ return True
+
+ @defer.inlineCallbacks
+ def get_joined_hosts(self, room_id, state_entry):
+ state_group = state_entry.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ with Measure(self._clock, "get_joined_hosts"):
+ return (
+ yield self._get_joined_hosts(
+ room_id, state_group, state_entry.state, state_entry=state_entry
+ )
+ )
+
+ @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
+ def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
+ # We don't use `state_group`, it's there so that we can cache based
+ # on it. However, it's important that it's never None, since two
+ # current_states with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ cache = yield self._get_joined_hosts_cache(room_id)
+ joined_hosts = yield cache.get_destinations(state_entry)
+
+ return joined_hosts
+
+ @cached(max_entries=10000)
+ def _get_joined_hosts_cache(self, room_id):
+ return _JoinedHostsCache(self, room_id)
+
+ @cachedInlineCallbacks(num_args=2)
+ def did_forget(self, user_id, room_id):
+ """Returns whether user_id has elected to discard history for room_id.
+
+ Returns False if they have since re-joined."""
+
+ def f(txn):
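+ # Count this user's membership rows for the room that are NOT marked
+ # forgotten; the room only counts as forgotten when every row is.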
+ sql = (
+ "SELECT"
+ " COUNT(*)"
+ " FROM"
+ " room_memberships"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ " AND"
+ " forgotten = 0"
+ )
+ txn.execute(sql, (user_id, room_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+
+ count = yield self.db.runInteraction("did_forget_membership", f)
+ return count == 0
+
+ @cached()
+ def get_forgotten_rooms_for_user(self, user_id):
+ """Gets all rooms the user has forgotten.
+
+ Args:
+ user_id (str)
+
+ Returns:
+ Deferred[set[str]]
+ """
+
+ def _get_forgotten_rooms_for_user_txn(txn):
+ # This is a slightly convoluted query that first looks up all rooms
+ # that the user has forgotten in the past, then rechecks that list
+ # to see if any have subsequently been updated. This is done so that
+ # we can use a partial index on `forgotten = 1` on the assumption
+ # that few users will actually forget many rooms.
+ #
+ # Note that a room is considered "forgotten" if *all* membership
+ # events for that user and room have the forgotten field set (as
+ # when a user forgets a room we update all rows for that user and
+ # room, not just the current one).
+ sql = """
+ SELECT room_id, (
+ SELECT count(*) FROM room_memberships
+ WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
+ ) AS count
+ FROM room_memberships AS m
+ WHERE user_id = ? AND forgotten = 1
+ GROUP BY room_id, user_id;
+ """
+ txn.execute(sql, (user_id,))
+ return {row[0] for row in txn if row[1] == 0}
+
+ return self.db.runInteraction(
+ "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms_user_has_been_in(self, user_id):
+ """Get all rooms that the user has ever been in.
+
+ Args:
+ user_id (str)
+
+ Returns:
+ Deferred[set[str]]: Set of room IDs.
+ """
+
+ room_ids = yield self.db.simple_select_onecol(
+ table="room_memberships",
+ keyvalues={"membership": Membership.JOIN, "user_id": user_id},
+ retcol="room_id",
+ desc="get_rooms_user_has_been_in",
+ )
+
+ return set(room_ids)
+
+ def get_membership_from_event_ids(
+ self, member_event_ids: Iterable[str]
+ ) -> List[dict]:
+ """Get user_id and membership of a set of event IDs.
+ """
+
+ return self.db.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids,
+ retcols=("user_id", "membership", "event_id"),
+ keyvalues={},
+ batch_size=500,
+ desc="get_membership_from_event_ids",
+ )
+
+ async def is_local_host_in_room_ignoring_users(
+ self, room_id: str, ignore_users: Collection[str]
+ ) -> bool:
+ """Check if there are any local users, excluding those in the given
+ list, in the room.
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "user_id", ignore_users
+ )
+
+ sql = """
+ SELECT 1 FROM local_current_membership
+ WHERE
+ room_id = ? AND membership = ?
+ AND NOT (%s)
+ LIMIT 1
+ """ % (
+ clause,
+ )
+
+ def _is_local_host_in_room_ignoring_users_txn(txn):
+ txn.execute(sql, (room_id, Membership.JOIN, *args))
+
+ return bool(txn.fetchone())
+
+ return await self.db.runInteraction(
+ "is_local_host_in_room_ignoring_users",
+ _is_local_host_in_room_ignoring_users_txn,
+ )
+
+
+class RoomMemberBackgroundUpdateStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ self.db.updates.register_background_update_handler(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+ )
+ self.db.updates.register_background_update_handler(
+ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+ self._background_current_state_membership,
+ )
+ self.db.updates.register_background_index_update(
+ "room_membership_forgotten_idx",
+ index_name="room_memberships_user_room_forgotten",
+ table="room_memberships",
+ columns=["user_id", "room_id"],
+ where_clause="forgotten = 1",
+ )
+
+ @defer.inlineCallbacks
+ def _background_add_membership_profile(self, progress, batch_size):
+ target_min_stream_id = progress.get(
+ "target_min_stream_id_inclusive", self._min_stream_order_on_start
+ )
+ max_stream_id = progress.get(
+ "max_stream_id_exclusive", self._stream_order_on_start + 1
+ )
+
+ INSERT_CLUMP_SIZE = 1000
+
+ def add_membership_profile_txn(txn):
+ sql = """
+ SELECT stream_ordering, event_id, events.room_id, event_json.json
+ FROM events
+ INNER JOIN event_json USING (event_id)
+ INNER JOIN room_memberships USING (event_id)
+ WHERE ? <= stream_ordering AND stream_ordering < ?
+ AND type = 'm.room.member'
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+ rows = self.db.cursor_to_dict(txn)
+ if not rows:
+ return 0
+
+ min_stream_id = rows[-1]["stream_ordering"]
+
+ to_update = []
+ for row in rows:
+ event_id = row["event_id"]
+ room_id = row["room_id"]
+ try:
+ event_json = json.loads(row["json"])
+ content = event_json["content"]
+ except Exception:
+ continue
+
+ display_name = content.get("displayname", None)
+ avatar_url = content.get("avatar_url", None)
+
+ if display_name or avatar_url:
+ to_update.append((display_name, avatar_url, event_id, room_id))
+
+ to_update_sql = """
+ UPDATE room_memberships SET display_name = ?, avatar_url = ?
+ WHERE event_id = ? AND room_id = ?
+ """
+ for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
+ clump = to_update[index : index + INSERT_CLUMP_SIZE]
+ txn.executemany(to_update_sql, clump)
+
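+ # We iterate through the events backwards (stream_ordering DESC), so
+ # the lowest stream ordering we just processed becomes the next
+ # batch's exclusive upper bound.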
+ progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ }
+
+ self.db.updates._background_update_progress_txn(
+ txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
+ )
+
+ return len(rows)
+
+ result = yield self.db.runInteraction(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
+ )
+
+ if not result:
+ yield self.db.updates._end_background_update(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME
+ )
+
+ return result
+
+ @defer.inlineCallbacks
+ def _background_current_state_membership(self, progress, batch_size):
+ """Update the new membership column on current_state_events.
+
+ This works by iterating over all rooms in alphabetical order.
+ """
+
+ def _background_current_state_membership_txn(txn, last_processed_room):
+ processed = 0
+ while processed < batch_size:
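+ # Find the next room ID (in lexicographic order) after the one we
+ # processed last; a NULL result means every room has been handled.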
+ txn.execute(
+ """
+ SELECT MIN(room_id) FROM current_state_events WHERE room_id > ?
+ """,
+ (last_processed_room,),
+ )
+ row = txn.fetchone()
+ if not row or not row[0]:
+ return processed, True
+
+ (next_room,) = row
+
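+ # Copy the membership value from room_memberships onto each of this
+ # room's current_state_events rows in a single UPDATE.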
+ sql = """
+ UPDATE current_state_events
+ SET membership = (
+ SELECT membership FROM room_memberships
+ WHERE event_id = current_state_events.event_id
+ )
+ WHERE room_id = ?
+ """
+ txn.execute(sql, (next_room,))
+ processed += txn.rowcount
+
+ last_processed_room = next_room
+
+ self.db.updates._background_update_progress_txn(
+ txn,
+ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+ {"last_processed_room": last_processed_room},
+ )
+
+ return processed, False
+
+ # If we haven't got a last processed room then just use the empty
+ # string, which will compare before all room IDs correctly.
+ last_processed_room = progress.get("last_processed_room", "")
+
+ row_count, finished = yield self.db.runInteraction(
+ "_background_current_state_membership_update",
+ _background_current_state_membership_txn,
+ last_processed_room,
+ )
+
+ if finished:
+ yield self.db.updates._end_background_update(
+ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
+ )
+
+ return row_count
+
+
+class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomMemberStore, self).__init__(database, db_conn, hs)
+
+ def _store_room_members_txn(self, txn, events, backfilled):
+ """Store a room member in the database.
+ """
+ self.db.simple_insert_many_txn(
+ txn,
+ table="room_memberships",
+ values=[
+ {
+ "event_id": event.event_id,
+ "user_id": event.state_key,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ "display_name": event.content.get("displayname", None),
+ "avatar_url": event.content.get("avatar_url", None),
+ }
+ for event in events
+ ],
+ )
+
+ for event in events:
+ txn.call_after(
+ self._membership_stream_cache.entity_has_changed,
+ event.state_key,
+ event.internal_metadata.stream_ordering,
+ )
+ txn.call_after(
+ self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
+ )
+
+ # We update the local_invites table only if the event is "current",
+ # i.e., it's something that has just happened. If the event is an
+ # outlier it is only current if it's an "out of band membership",
+ # like a remote invite or a rejection of a remote invite.
+ is_new_state = not backfilled and (
+ not event.internal_metadata.is_outlier()
+ or event.internal_metadata.is_out_of_band_membership()
+ )
+ is_mine = self.hs.is_mine_id(event.state_key)
+ if is_new_state and is_mine:
+ if event.membership == Membership.INVITE:
+ self.db.simple_insert_txn(
+ txn,
+ table="local_invites",
+ values={
+ "event_id": event.event_id,
+ "invitee": event.state_key,
+ "inviter": event.sender,
+ "room_id": event.room_id,
+ "stream_id": event.internal_metadata.stream_ordering,
+ },
+ )
+ else:
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ txn.execute(
+ sql,
+ (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ),
+ )
+
+ # We also update the `local_current_membership` table with
+ # latest invite info. This will usually get updated by the
+ # `current_state_events` handling, unless it's an outlier.
+ if event.internal_metadata.is_outlier():
+ # This should only happen for out of band memberships, so
+ # we add a paranoia check.
+ assert event.internal_metadata.is_out_of_band_membership()
+
+ self.db.simple_upsert_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={
+ "room_id": event.room_id,
+ "user_id": event.state_key,
+ },
+ values={
+ "event_id": event.event_id,
+ "membership": event.membership,
+ },
+ )
+
+ @defer.inlineCallbacks
+ def locally_reject_invite(self, user_id, room_id):
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ def f(txn, stream_ordering):
+ txn.execute(sql, (stream_ordering, True, room_id, user_id))
+
+ # We also clear this entry from `local_current_membership`.
+ # Ideally we'd point to a leave event, but we don't have one, so
+ # never mind.
+ self.db.simple_delete_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ )
+
+ with self._stream_id_gen.get_next() as stream_ordering:
+ yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
+
+ def forget(self, user_id, room_id):
+ """Indicate that user_id wishes to discard history for room_id."""
+
+ def f(txn):
+ sql = (
+ "UPDATE"
+ " room_memberships"
+ " SET"
+ " forgotten = 1"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ )
+ txn.execute(sql, (user_id, room_id))
+
+ self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
+ self._invalidate_cache_and_stream(
+ txn, self.get_forgotten_rooms_for_user, (user_id,)
+ )
+
+ return self.db.runInteraction("forget_membership", f)
+
+
+class _JoinedHostsCache(object):
+ """Cache for joined hosts in a room that is optimised to handle updates
+ via state deltas.
+ """
+
+ def __init__(self, store, room_id):
+ self.store = store
+ self.room_id = room_id
+
+ self.hosts_to_joined_users = {}
+
+ self.state_group = object()
+
+ self.linearizer = Linearizer("_JoinedHostsCache")
+
+ self._len = 0
+
+ @defer.inlineCallbacks
+ def get_destinations(self, state_entry):
+ """Get set of destinations for a state entry
+
+ Args:
+ state_entry(synapse.state._StateCacheEntry)
+ """
+ if state_entry.state_group == self.state_group:
+ return frozenset(self.hosts_to_joined_users)
+
+ with (yield self.linearizer.queue(())):
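+ # Re-check now that we hold the linearizer, in case another caller
+ # refreshed the cache while we were waiting for it.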
+ if state_entry.state_group == self.state_group:
+ pass
+ elif state_entry.prev_group == self.state_group:
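+ # The cache matches the previous state group, so apply just the
+ # membership changes recorded in delta_ids on top of it.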
+ for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
+ if typ != EventTypes.Member:
+ continue
+
+ host = intern_string(get_domain_from_id(state_key))
+ user_id = state_key
+ known_joins = self.hosts_to_joined_users.setdefault(host, set())
+
+ event = yield self.store.get_event(event_id)
+ if event.membership == Membership.JOIN:
+ known_joins.add(user_id)
+ else:
+ known_joins.discard(user_id)
+
+ if not known_joins:
+ self.hosts_to_joined_users.pop(host, None)
+ else:
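+ # No usable cached state: rebuild the host -> joined-users map from
+ # the full set of joined members in this state.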
+ joined_users = yield self.store.get_joined_users_from_state(
+ self.room_id, state_entry
+ )
+
+ self.hosts_to_joined_users = {}
+ for user_id in joined_users:
+ host = intern_string(get_domain_from_id(user_id))
+ self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+ if state_entry.state_group:
+ self.state_group = state_entry.state_group
+ else:
+ self.state_group = object()
+ self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
+ return frozenset(self.hosts_to_joined_users)
+
+ def __len__(self):
+ return self._len
diff --git a/synapse/storage/schema/delta/12/v12.sql b/synapse/storage/data_stores/main/schema/delta/12/v12.sql
index 5964c5aaac..5964c5aaac 100644
--- a/synapse/storage/schema/delta/12/v12.sql
+++ b/synapse/storage/data_stores/main/schema/delta/12/v12.sql
diff --git a/synapse/storage/schema/delta/13/v13.sql b/synapse/storage/data_stores/main/schema/delta/13/v13.sql
index f8649e5d99..f8649e5d99 100644
--- a/synapse/storage/schema/delta/13/v13.sql
+++ b/synapse/storage/data_stores/main/schema/delta/13/v13.sql
diff --git a/synapse/storage/schema/delta/14/v14.sql b/synapse/storage/data_stores/main/schema/delta/14/v14.sql
index a831920da6..a831920da6 100644
--- a/synapse/storage/schema/delta/14/v14.sql
+++ b/synapse/storage/data_stores/main/schema/delta/14/v14.sql
diff --git a/synapse/storage/schema/delta/15/appservice_txns.sql b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql
index e4f5e76aec..e4f5e76aec 100644
--- a/synapse/storage/schema/delta/15/appservice_txns.sql
+++ b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql
diff --git a/synapse/storage/schema/delta/15/presence_indices.sql b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql
index 6b8d0f1ca7..6b8d0f1ca7 100644
--- a/synapse/storage/schema/delta/15/presence_indices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql
diff --git a/synapse/storage/schema/delta/15/v15.sql b/synapse/storage/data_stores/main/schema/delta/15/v15.sql
index 9523d2bcc3..9523d2bcc3 100644
--- a/synapse/storage/schema/delta/15/v15.sql
+++ b/synapse/storage/data_stores/main/schema/delta/15/v15.sql
diff --git a/synapse/storage/schema/delta/16/events_order_index.sql b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql
index a48f215170..a48f215170 100644
--- a/synapse/storage/schema/delta/16/events_order_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql
diff --git a/synapse/storage/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql
index 7a15265cb1..7a15265cb1 100644
--- a/synapse/storage/schema/delta/16/remote_media_cache_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql
diff --git a/synapse/storage/schema/delta/16/remove_duplicates.sql b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql
index 65c97b5e2f..65c97b5e2f 100644
--- a/synapse/storage/schema/delta/16/remove_duplicates.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql
diff --git a/synapse/storage/schema/delta/16/room_alias_index.sql b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql
index f82486132b..f82486132b 100644
--- a/synapse/storage/schema/delta/16/room_alias_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql
diff --git a/synapse/storage/schema/delta/16/unique_constraints.sql b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql
index 5b8de52c33..5b8de52c33 100644
--- a/synapse/storage/schema/delta/16/unique_constraints.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql
diff --git a/synapse/storage/schema/delta/16/users.sql b/synapse/storage/data_stores/main/schema/delta/16/users.sql
index cd0709250d..cd0709250d 100644
--- a/synapse/storage/schema/delta/16/users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/users.sql
diff --git a/synapse/storage/schema/delta/17/drop_indexes.sql b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql
index 7c9a90e27f..7c9a90e27f 100644
--- a/synapse/storage/schema/delta/17/drop_indexes.sql
+++ b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql
diff --git a/synapse/storage/schema/delta/17/server_keys.sql b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql
index 70b247a06b..70b247a06b 100644
--- a/synapse/storage/schema/delta/17/server_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql
diff --git a/synapse/storage/schema/delta/17/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql
index c17715ac80..c17715ac80 100644
--- a/synapse/storage/schema/delta/17/user_threepids.sql
+++ b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql
diff --git a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql
index 6e0871c92b..6e0871c92b 100644
--- a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql
+++ b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql
diff --git a/synapse/storage/schema/delta/19/event_index.sql b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql
index 18b97b4332..18b97b4332 100644
--- a/synapse/storage/schema/delta/19/event_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql
diff --git a/synapse/storage/schema/delta/20/dummy.sql b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql
index e0ac49d1ec..e0ac49d1ec 100644
--- a/synapse/storage/schema/delta/20/dummy.sql
+++ b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql
diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/data_stores/main/schema/delta/20/pushers.py
index 147496a38b..3edfcfd783 100644
--- a/synapse/storage/schema/delta/20/pushers.py
+++ b/synapse/storage/data_stores/main/schema/delta/20/pushers.py
@@ -29,7 +29,8 @@ logger = logging.getLogger(__name__)
def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...")
- cur.execute("""
+ cur.execute(
+ """
CREATE TABLE IF NOT EXISTS pushers2 (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
@@ -48,27 +49,34 @@ def run_create(cur, database_engine, *args, **kwargs):
failing_since BIGINT,
UNIQUE (app_id, pushkey, user_name)
)
- """)
- cur.execute("""SELECT
+ """
+ )
+ cur.execute(
+ """SELECT
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
FROM pushers
- """)
+ """
+ )
count = 0
for row in cur.fetchall():
row = list(row)
row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8")
- cur.execute(database_engine.convert_param_style("""
+ cur.execute(
+ database_engine.convert_param_style(
+ """
INSERT into pushers2 (
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
- ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
- row
+ ) values (%s)"""
+ % (",".join(["?" for _ in range(len(row))]))
+ ),
+ row,
)
count += 1
cur.execute("DROP TABLE pushers")
diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql
index 4c2fb20b77..4c2fb20b77 100644
--- a/synapse/storage/schema/delta/21/end_to_end_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql
diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql
index d070845477..d070845477 100644
--- a/synapse/storage/schema/delta/21/receipts.sql
+++ b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql
diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql
index bfc0b3bcaa..bfc0b3bcaa 100644
--- a/synapse/storage/schema/delta/22/receipts_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql
diff --git a/synapse/storage/schema/delta/22/user_threepids_unique.sql b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql
index 87edfa454c..87edfa454c 100644
--- a/synapse/storage/schema/delta/22/user_threepids_unique.sql
+++ b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql
diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql
index acea7483bd..acea7483bd 100644
--- a/synapse/storage/schema/delta/24/stats_reporting.sql
+++ b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/data_stores/main/schema/delta/25/fts.py
index 4b2ffd35fd..4b2ffd35fd 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/data_stores/main/schema/delta/25/fts.py
diff --git a/synapse/storage/schema/delta/25/guest_access.sql b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql
index 1ea389b471..1ea389b471 100644
--- a/synapse/storage/schema/delta/25/guest_access.sql
+++ b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql
diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql
index f468fc1897..f468fc1897 100644
--- a/synapse/storage/schema/delta/25/history_visibility.sql
+++ b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql
diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/data_stores/main/schema/delta/25/tags.sql
index 7a32ce68e4..7a32ce68e4 100644
--- a/synapse/storage/schema/delta/25/tags.sql
+++ b/synapse/storage/data_stores/main/schema/delta/25/tags.sql
diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql
index e395de2b5e..e395de2b5e 100644
--- a/synapse/storage/schema/delta/26/account_data.sql
+++ b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql
diff --git a/synapse/storage/schema/delta/27/account_data.sql b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql
index bf0558b5b3..bf0558b5b3 100644
--- a/synapse/storage/schema/delta/27/account_data.sql
+++ b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql
diff --git a/synapse/storage/schema/delta/27/forgotten_memberships.sql b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql
index e2094f37fe..e2094f37fe 100644
--- a/synapse/storage/schema/delta/27/forgotten_memberships.sql
+++ b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/data_stores/main/schema/delta/27/ts.py
index 414f9f5aa0..414f9f5aa0 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/data_stores/main/schema/delta/27/ts.py
diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql
index 4d519849df..4d519849df 100644
--- a/synapse/storage/schema/delta/28/event_push_actions.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql
diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql
index 36609475f1..36609475f1 100644
--- a/synapse/storage/schema/delta/28/events_room_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql
diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql
index 6c1fd68c5b..6c1fd68c5b 100644
--- a/synapse/storage/schema/delta/28/public_roms_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql
diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql
index cb84c69baa..cb84c69baa 100644
--- a/synapse/storage/schema/delta/28/receipts_user_id_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql
diff --git a/synapse/storage/schema/delta/28/upgrade_times.sql b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql
index 3e4a9ab455..3e4a9ab455 100644
--- a/synapse/storage/schema/delta/28/upgrade_times.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql
diff --git a/synapse/storage/schema/delta/28/users_is_guest.sql b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql
index 21d2b420bf..21d2b420bf 100644
--- a/synapse/storage/schema/delta/28/users_is_guest.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql
diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql
index 84b21cf813..84b21cf813 100644
--- a/synapse/storage/schema/delta/29/push_actions.sql
+++ b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql
diff --git a/synapse/storage/schema/delta/30/alias_creator.sql b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql
index c9d0dde638..c9d0dde638 100644
--- a/synapse/storage/schema/delta/30/alias_creator.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/data_stores/main/schema/delta/30/as_users.py
index ef7ec34346..9b95411fb6 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/data_stores/main/schema/delta/30/as_users.py
@@ -40,9 +40,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
logger.warning("Could not get app_service_config_files from config")
pass
- appservices = load_appservices(
- config.server_name, config_files
- )
+ appservices = load_appservices(config.server_name, config_files)
owned = {}
@@ -53,20 +51,19 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
if user_id in owned.keys():
logger.error(
"user_id %s was owned by more than one application"
- " service (IDs %s and %s); assigning arbitrarily to %s" %
- (user_id, owned[user_id], appservice.id, owned[user_id])
+ " service (IDs %s and %s); assigning arbitrarily to %s"
+ % (user_id, owned[user_id], appservice.id, owned[user_id])
)
owned.setdefault(appservice.id, []).append(user_id)
for as_id, user_ids in owned.items():
n = 100
- user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n))
+ user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks:
cur.execute(
database_engine.convert_param_style(
- "UPDATE users SET appservice_id = ? WHERE name IN (%s)" % (
- ",".join("?" for _ in chunk),
- )
+ "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
+ % (",".join("?" for _ in chunk),)
),
- [as_id] + chunk
+ [as_id] + chunk,
)
diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql
index 712c454aa1..712c454aa1 100644
--- a/synapse/storage/schema/delta/30/deleted_pushers.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql
diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql
index 606bbb037d..606bbb037d 100644
--- a/synapse/storage/schema/delta/30/presence_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql
diff --git a/synapse/storage/schema/delta/30/public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql
index f09db4faa6..f09db4faa6 100644
--- a/synapse/storage/schema/delta/30/public_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql
diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql
index 735aa8d5f6..735aa8d5f6 100644
--- a/synapse/storage/schema/delta/30/push_rule_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql
diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql
index 0dd2f1360c..0dd2f1360c 100644
--- a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql
diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/data_stores/main/schema/delta/31/invites.sql
index 2c57846d5a..2c57846d5a 100644
--- a/synapse/storage/schema/delta/31/invites.sql
+++ b/synapse/storage/data_stores/main/schema/delta/31/invites.sql
diff --git a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql
index 9efb4280eb..9efb4280eb 100644
--- a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql
diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/data_stores/main/schema/delta/31/pushers.py
index 93367fa09e..9bb504aad5 100644
--- a/synapse/storage/schema/delta/31/pushers.py
+++ b/synapse/storage/data_stores/main/schema/delta/31/pushers.py
@@ -24,12 +24,13 @@ logger = logging.getLogger(__name__)
def token_to_stream_ordering(token):
- return int(token[1:].split('_')[0])
+ return int(token[1:].split("_")[0])
def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table, delta 31...")
- cur.execute("""
+ cur.execute(
+ """
CREATE TABLE IF NOT EXISTS pushers2 (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
@@ -48,26 +49,33 @@ def run_create(cur, database_engine, *args, **kwargs):
failing_since BIGINT,
UNIQUE (app_id, pushkey, user_name)
)
- """)
- cur.execute("""SELECT
+ """
+ )
+ cur.execute(
+ """SELECT
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
FROM pushers
- """)
+ """
+ )
count = 0
for row in cur.fetchall():
row = list(row)
row[12] = token_to_stream_ordering(row[12])
- cur.execute(database_engine.convert_param_style("""
+ cur.execute(
+ database_engine.convert_param_style(
+ """
INSERT into pushers2 (
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_stream_ordering, last_success,
failing_since
- ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
- row
+ ) values (%s)"""
+ % (",".join(["?" for _ in range(len(row))]))
+ ),
+ row,
)
count += 1
cur.execute("DROP TABLE pushers")
diff --git a/synapse/storage/schema/delta/31/pushers_index.sql b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql
index a82add88fd..a82add88fd 100644
--- a/synapse/storage/schema/delta/31/pushers_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/data_stores/main/schema/delta/31/search_update.py
index 7d8ca5f93f..7d8ca5f93f 100644
--- a/synapse/storage/schema/delta/31/search_update.py
+++ b/synapse/storage/data_stores/main/schema/delta/31/search_update.py
diff --git a/synapse/storage/schema/delta/32/events.sql b/synapse/storage/data_stores/main/schema/delta/32/events.sql
index 1dd0f9e170..1dd0f9e170 100644
--- a/synapse/storage/schema/delta/32/events.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/events.sql
diff --git a/synapse/storage/schema/delta/32/openid.sql b/synapse/storage/data_stores/main/schema/delta/32/openid.sql
index 36f37b11c8..36f37b11c8 100644
--- a/synapse/storage/schema/delta/32/openid.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/openid.sql
diff --git a/synapse/storage/schema/delta/32/pusher_throttle.sql b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql
index d86d30c13c..d86d30c13c 100644
--- a/synapse/storage/schema/delta/32/pusher_throttle.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql
diff --git a/synapse/storage/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
index 4219cdd06a..2de50d408c 100644
--- a/synapse/storage/schema/delta/32/remove_indices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
@@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream
DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY
-DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT
diff --git a/synapse/storage/schema/delta/32/reports.sql b/synapse/storage/data_stores/main/schema/delta/32/reports.sql
index d13609776f..d13609776f 100644
--- a/synapse/storage/schema/delta/32/reports.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/reports.sql
diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql
index 61ad3fe3e8..61ad3fe3e8 100644
--- a/synapse/storage/schema/delta/33/access_tokens_device_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql
diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/data_stores/main/schema/delta/33/devices.sql
index eca7268d82..eca7268d82 100644
--- a/synapse/storage/schema/delta/33/devices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/devices.sql
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql
index aa4a3b9f2f..aa4a3b9f2f 100644
--- a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
index 6671573398..6671573398 100644
--- a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
index bff1256a7b..bff1256a7b 100644
--- a/synapse/storage/schema/delta/33/event_fields.py
+++ b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py
index 9754d3ccfb..a26057dfb6 100644
--- a/synapse/storage/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py
@@ -26,5 +26,5 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
database_engine.convert_param_style(
"UPDATE remote_media_cache SET last_access_ts = ?"
),
- (int(time.time() * 1000),)
+ (int(time.time() * 1000),),
)
diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql
index 473f75a78e..473f75a78e 100644
--- a/synapse/storage/schema/delta/33/user_ips_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql
diff --git a/synapse/storage/schema/delta/34/appservice_stream.sql b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql
index 69e16eda0f..69e16eda0f 100644
--- a/synapse/storage/schema/delta/34/appservice_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql
diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py
index cf09e43e2b..cf09e43e2b 100644
--- a/synapse/storage/schema/delta/34/cache_stream.py
+++ b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py
diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql
index e68844c74a..e68844c74a 100644
--- a/synapse/storage/schema/delta/34/device_inbox.sql
+++ b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql
diff --git a/synapse/storage/schema/delta/34/push_display_name_rename.sql b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql
index 0d9fe1a99a..0d9fe1a99a 100644
--- a/synapse/storage/schema/delta/34/push_display_name_rename.sql
+++ b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql
diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py
index 67d505e68b..67d505e68b 100644
--- a/synapse/storage/schema/delta/34/received_txn_purge.py
+++ b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py
diff --git a/synapse/storage/schema/delta/35/contains_url.sql b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql
index 6cd123027b..6cd123027b 100644
--- a/synapse/storage/schema/delta/35/contains_url.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql
diff --git a/synapse/storage/schema/delta/35/device_outbox.sql b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql
index 17e6c43105..17e6c43105 100644
--- a/synapse/storage/schema/delta/35/device_outbox.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql
diff --git a/synapse/storage/schema/delta/35/device_stream_id.sql b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql
index 7ab7d942e2..7ab7d942e2 100644
--- a/synapse/storage/schema/delta/35/device_stream_id.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql
diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql
index 2e836d8e9c..2e836d8e9c 100644
--- a/synapse/storage/schema/delta/35/event_push_actions_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql
diff --git a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql
index dd2bf2e28a..dd2bf2e28a 100644
--- a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql
diff --git a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql
index 2b945d8a57..2b945d8a57 100644
--- a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql
diff --git a/synapse/storage/schema/delta/36/readd_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql
index 90d8fd18f9..90d8fd18f9 100644
--- a/synapse/storage/schema/delta/36/readd_public_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql
diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py
index a377884169..a377884169 100644
--- a/synapse/storage/schema/delta/37/remove_auth_idx.py
+++ b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py
diff --git a/synapse/storage/schema/delta/37/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql
index cf7a90dd10..cf7a90dd10 100644
--- a/synapse/storage/schema/delta/37/user_threepids.sql
+++ b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql
diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql
index 515e6b8e84..515e6b8e84 100644
--- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql
+++ b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql
diff --git a/synapse/storage/schema/delta/39/appservice_room_list.sql b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql
index 74bdc49073..74bdc49073 100644
--- a/synapse/storage/schema/delta/39/appservice_room_list.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql
diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql
index 00be801e90..00be801e90 100644
--- a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql
diff --git a/synapse/storage/schema/delta/39/event_push_index.sql b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql
index de2ad93e5c..de2ad93e5c 100644
--- a/synapse/storage/schema/delta/39/event_push_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql
diff --git a/synapse/storage/schema/delta/39/federation_out_position.sql b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql
index 5af814290b..5af814290b 100644
--- a/synapse/storage/schema/delta/39/federation_out_position.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql
diff --git a/synapse/storage/schema/delta/39/membership_profile.sql b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql
index 1bf911c8ab..1bf911c8ab 100644
--- a/synapse/storage/schema/delta/39/membership_profile.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql
diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql
index 7ffa189f39..7ffa189f39 100644
--- a/synapse/storage/schema/delta/40/current_state_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql
diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql
index b9fe1f0480..b9fe1f0480 100644
--- a/synapse/storage/schema/delta/40/device_inbox.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql
diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql
index dd6dcb65f1..dd6dcb65f1 100644
--- a/synapse/storage/schema/delta/40/device_list_streams.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql
diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql
index 3918f0b794..3918f0b794 100644
--- a/synapse/storage/schema/delta/40/event_push_summary.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql
diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql
index 054a223f14..054a223f14 100644
--- a/synapse/storage/schema/delta/40/pushers.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql
diff --git a/synapse/storage/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql
index b7bee8b692..b7bee8b692 100644
--- a/synapse/storage/schema/delta/41/device_list_stream_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql
diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql
index 62f0b9892b..62f0b9892b 100644
--- a/synapse/storage/schema/delta/41/device_outbound_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql
diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql
index 5d9cfecf36..5d9cfecf36 100644
--- a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql
diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql
index a194bf0238..a194bf0238 100644
--- a/synapse/storage/schema/delta/41/ratelimit.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql
diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql
index d28851aff8..d28851aff8 100644
--- a/synapse/storage/schema/delta/42/current_state_delta.sql
+++ b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql
diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql
index 9ab8c14fa3..9ab8c14fa3 100644
--- a/synapse/storage/schema/delta/42/device_list_last_id.sql
+++ b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql
diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql
index b8821ac759..b8821ac759 100644
--- a/synapse/storage/schema/delta/42/event_auth_state_only.sql
+++ b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py
index 506f326f4d..506f326f4d 100644
--- a/synapse/storage/schema/delta/42/user_dir.py
+++ b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py
diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql
index 0e3cd143ff..0e3cd143ff 100644
--- a/synapse/storage/schema/delta/43/blocked_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql
diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql
index 630907ec4f..630907ec4f 100644
--- a/synapse/storage/schema/delta/43/quarantine_media.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql
diff --git a/synapse/storage/schema/delta/43/url_cache.sql b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql
index 45ebe020da..45ebe020da 100644
--- a/synapse/storage/schema/delta/43/url_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql
diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql
index ee7062abe4..ee7062abe4 100644
--- a/synapse/storage/schema/delta/43/user_share.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql
diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql
index b12f9b2ebf..b12f9b2ebf 100644
--- a/synapse/storage/schema/delta/44/expire_url_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql
diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql
index b2333848a0..b2333848a0 100644
--- a/synapse/storage/schema/delta/45/group_server.sql
+++ b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql
diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql
index e5ddc84df0..e5ddc84df0 100644
--- a/synapse/storage/schema/delta/45/profile_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql
diff --git a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql
index 68c48a89a9..68c48a89a9 100644
--- a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql
diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql
index bb307889c1..bb307889c1 100644
--- a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql
diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql
index 097679bc9a..097679bc9a 100644
--- a/synapse/storage/schema/delta/46/group_server.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql
diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql
index bbfc7f5d1a..bbfc7f5d1a 100644
--- a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql
diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql
index cb0d5a2576..cb0d5a2576 100644
--- a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql
diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql
index d9505f8da1..d9505f8da1 100644
--- a/synapse/storage/schema/delta/46/user_dir_typos.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql
index f505fb22b5..f505fb22b5 100644
--- a/synapse/storage/schema/delta/47/last_access_media.sql
+++ b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql
diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql
index 31d7a817eb..31d7a817eb 100644
--- a/synapse/storage/schema/delta/47/postgres_fts_gin.sql
+++ b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql
diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql
index edccf4a96f..edccf4a96f 100644
--- a/synapse/storage/schema/delta/47/push_actions_staging.sql
+++ b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql
diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql
index 5237491506..5237491506 100644
--- a/synapse/storage/schema/delta/48/add_user_consent.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql
diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql
index 9248b0b24a..9248b0b24a 100644
--- a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql
diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql
index e9013a6969..e9013a6969 100644
--- a/synapse/storage/schema/delta/48/deactivated_users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql
diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py
index 2233af87d7..49f5f2c003 100644
--- a/synapse/storage/schema/delta/48/group_unique_indexes.py
+++ b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py
@@ -38,16 +38,22 @@ def run_create(cur, database_engine, *args, **kwargs):
rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
# remove duplicates from group_users & group_invites tables
- cur.execute("""
+ cur.execute(
+ """
DELETE FROM group_users WHERE %s NOT IN (
SELECT min(%s) FROM group_users GROUP BY group_id, user_id
);
- """ % (rowid, rowid))
- cur.execute("""
+ """
+ % (rowid, rowid)
+ )
+ cur.execute(
+ """
DELETE FROM group_invites WHERE %s NOT IN (
SELECT min(%s) FROM group_invites GROUP BY group_id, user_id
);
- """ % (rowid, rowid))
+ """
+ % (rowid, rowid)
+ )
for statement in get_statements(FIX_INDEXES.splitlines()):
cur.execute(statement)
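(Illustrative note, not part of the diff: the reformatted statements above rely on the "keep the row with the smallest rowid/ctid per key" idiom to strip duplicates before a unique index can be created. A minimal sketch of that idiom against a throwaway SQLite table; the table contents are hypothetical.)

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE group_users (group_id TEXT, user_id TEXT)")
cur.executemany(
    "INSERT INTO group_users VALUES (?, ?)",
    [("+g:hs", "@a:hs"), ("+g:hs", "@a:hs"), ("+g:hs", "@b:hs")],  # one duplicate
)

# Keep only the earliest row (smallest rowid) for each (group_id, user_id) pair,
# mirroring the DELETE statements in the delta above.
cur.execute(
    """
    DELETE FROM group_users WHERE rowid NOT IN (
        SELECT min(rowid) FROM group_users GROUP BY group_id, user_id
    )
    """
)

print(cur.execute("SELECT count(*) FROM group_users").fetchone())  # prints (2,)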
diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql
index ce26eaf0c9..ce26eaf0c9 100644
--- a/synapse/storage/schema/delta/48/groups_joinable.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql
diff --git a/synapse/storage/schema/delta/48/profiles_batch.sql b/synapse/storage/data_stores/main/schema/delta/48/profiles_batch.sql
index e744c02fe8..e744c02fe8 100644
--- a/synapse/storage/schema/delta/48/profiles_batch.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/profiles_batch.sql
diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql
index 14dcf18d73..14dcf18d73 100644
--- a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
+++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql
diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql
index 3dd478196f..3dd478196f 100644
--- a/synapse/storage/schema/delta/49/add_user_daily_visits.sql
+++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql
diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
index 3a4ed59b5b..3a4ed59b5b 100644
--- a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
diff --git a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql
index c93ae47532..c93ae47532 100644
--- a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql
diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql
index 5d8641a9ab..5d8641a9ab 100644
--- a/synapse/storage/schema/delta/50/erasure_store.sql
+++ b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql
diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py
index 6dd467b6c5..b1684a8441 100644
--- a/synapse/storage/schema/delta/50/make_event_content_nullable.py
+++ b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py
@@ -65,14 +65,18 @@ def run_create(cur, database_engine, *args, **kwargs):
def run_upgrade(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
- cur.execute("""
+ cur.execute(
+ """
ALTER TABLE events ALTER COLUMN content DROP NOT NULL;
- """)
+ """
+ )
return
# sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html
- cur.execute("SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'")
+ cur.execute(
+ "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'"
+ )
(oldsql,) = cur.fetchone()
sql = oldsql.replace("content TEXT NOT NULL", "content TEXT")
@@ -86,7 +90,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute("PRAGMA writable_schema=ON")
cur.execute(
"UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'",
- (sql, ),
+ (sql,),
)
cur.execute("PRAGMA schema_version=%i" % (oldver + 1,))
cur.execute("PRAGMA writable_schema=OFF")
diff --git a/synapse/storage/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/50/profiles_deactivated_users.sql
index c8893ecbe8..c8893ecbe8 100644
--- a/synapse/storage/schema/delta/50/profiles_deactivated_users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/50/profiles_deactivated_users.sql
diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql
index c0e66a697d..c0e66a697d 100644
--- a/synapse/storage/schema/delta/51/e2e_room_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql
diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql
index c9d537d5a3..c9d537d5a3 100644
--- a/synapse/storage/schema/delta/51/monthly_active_users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql
diff --git a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql
index 91e03d13e1..91e03d13e1 100644
--- a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql
diff --git a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql
index bfa49e6f92..bfa49e6f92 100644
--- a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql
diff --git a/synapse/storage/schema/delta/52/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql
index db687cccae..db687cccae 100644
--- a/synapse/storage/schema/delta/52/e2e_room_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql
diff --git a/synapse/storage/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql
index 88ec2f83e5..88ec2f83e5 100644
--- a/synapse/storage/schema/delta/53/add_user_type_to_users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql
diff --git a/synapse/storage/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql
index e372f5a44a..e372f5a44a 100644
--- a/synapse/storage/schema/delta/53/drop_sent_transactions.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql
diff --git a/synapse/storage/schema/delta/53/event_format_version.sql b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql
index 1d977c2834..1d977c2834 100644
--- a/synapse/storage/schema/delta/53/event_format_version.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql
diff --git a/synapse/storage/schema/delta/53/user_dir_populate.sql b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql
index ffcc896b58..ffcc896b58 100644
--- a/synapse/storage/schema/delta/53/user_dir_populate.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql
diff --git a/synapse/storage/schema/delta/53/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql
index b812c5794f..b812c5794f 100644
--- a/synapse/storage/schema/delta/53/user_ips_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql
diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql
index 5831b1a6f8..5831b1a6f8 100644
--- a/synapse/storage/schema/delta/53/user_share.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql
diff --git a/synapse/storage/schema/delta/53/user_threepid_id.sql b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql
index 80c2c573b6..80c2c573b6 100644
--- a/synapse/storage/schema/delta/53/user_threepid_id.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql
diff --git a/synapse/storage/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql
index f7827ca6d2..f7827ca6d2 100644
--- a/synapse/storage/schema/delta/53/users_in_public_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql
diff --git a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql
index 0adb2ad55e..0adb2ad55e 100644
--- a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql
diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql
index c01aa9d2d9..c01aa9d2d9 100644
--- a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql
diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql
index b062ec840c..b062ec840c 100644
--- a/synapse/storage/schema/delta/54/delete_forward_extremities.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql
diff --git a/synapse/storage/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql
index dbbe682697..dbbe682697 100644
--- a/synapse/storage/schema/delta/54/drop_legacy_tables.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql
diff --git a/synapse/storage/schema/delta/54/drop_presence_list.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql
index e6ee70c623..e6ee70c623 100644
--- a/synapse/storage/schema/delta/54/drop_presence_list.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql
diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/data_stores/main/schema/delta/54/relations.sql
index 134862b870..134862b870 100644
--- a/synapse/storage/schema/delta/54/relations.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/relations.sql
diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/data_stores/main/schema/delta/54/stats.sql
index 652e58308e..652e58308e 100644
--- a/synapse/storage/schema/delta/54/stats.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/stats.sql
diff --git a/synapse/storage/schema/delta/54/stats2.sql b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql
index 3b2d48447f..3b2d48447f 100644
--- a/synapse/storage/schema/delta/54/stats2.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql
new file mode 100644
index 0000000000..4590604bfd
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The time until which this access token can be used, in ms since the epoch.
+-- NULL means the token never expires.
+ALTER TABLE access_tokens ADD COLUMN valid_until_ms BIGINT;
diff --git a/synapse/storage/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/data_stores/main/schema/delta/55/profile_replication_status_index.sql
index 18a0f7e10c..18a0f7e10c 100644
--- a/synapse/storage/schema/delta/55/profile_replication_status_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/profile_replication_status_index.sql
diff --git a/synapse/storage/schema/delta/55/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/55/room_retention.sql
index ee6cdf7a14..ee6cdf7a14 100644
--- a/synapse/storage/schema/delta/55/room_retention.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/room_retention.sql
diff --git a/synapse/storage/schema/delta/55/track_threepid_validations.sql b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql
index a8eced2e0a..a8eced2e0a 100644
--- a/synapse/storage/schema/delta/55/track_threepid_validations.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql
diff --git a/synapse/storage/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql
index dabdde489b..dabdde489b 100644
--- a/synapse/storage/schema/delta/55/users_alter_deactivated.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql
new file mode 100644
index 0000000000..41807eb1e7
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Opentracing context data for inclusion in the device_list_update EDUs, as a
+ * json-encoded dictionary. NULL if opentracing is disabled (or not enabled for this destination).
+ */
+ALTER TABLE device_lists_outbound_pokes ADD opentracing_context TEXT;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql
new file mode 100644
index 0000000000..473018676f
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql
@@ -0,0 +1,22 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We add membership to current state so that we don't need to join against
+-- room_memberships, which can be surprisingly costly (we do such queries
+-- very frequently).
+-- This will be null for non-membership events, and the content.membership key
+-- for membership events. (It will also be null for membership events until the
+-- background update job has finished.)
+ALTER TABLE current_state_events ADD membership TEXT;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql
new file mode 100644
index 0000000000..3133d42d4a
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We add membership to current state so that we don't need to join against
+-- room_memberships, which can be surprisingly costly (we do such queries
+-- very frequently).
+-- This will be null for non-membership events, and the content.membership key
+-- for membership events. (It will also be null for membership events until the
+-- background update job has finished.)
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('current_state_events_membership', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
new file mode 100644
index 0000000000..1d2ddb1b1a
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Delete room keys that belong to deleted room key versions, or to room key
+ * versions that don't exist (anymore).
+ */
+DELETE FROM e2e_room_keys
+WHERE version NOT IN (
+ SELECT version
+ FROM e2e_room_keys_versions
+ WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id
+ AND e2e_room_keys_versions.deleted = 0
+);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql
new file mode 100644
index 0000000000..f00889290b
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Record the timestamp when a given server started failing
+ */
+ALTER TABLE destinations ADD failure_ts BIGINT;
+
+/* as a rough approximation, we assume that the server started failing at
+ * retry_interval before the last retry
+ */
+UPDATE destinations SET failure_ts = retry_last_ts - retry_interval
+ WHERE retry_last_ts > 0;
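(Illustrative note, not part of the diff: with hypothetical values retry_last_ts = 1,567,000,060,000 ms and retry_interval = 60,000 ms, the UPDATE above backfills failure_ts = 1,567,000,060,000 - 60,000 = 1,567,000,000,000 ms, i.e. one minute before the last retry; rows with retry_last_ts = 0 are left with failure_ts NULL.)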
diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
new file mode 100644
index 0000000000..b9bbb18a91
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
@@ -0,0 +1,18 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We want to store large retry intervals so we upgrade the column from INT
+-- to BIGINT. We don't need to do this on SQLite.
+ALTER TABLE destinations ALTER retry_interval SET DATA TYPE BIGINT;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
new file mode 100644
index 0000000000..c2f557fde9
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This line already existed in deltas/35/device_stream_id but was not included in the
+-- 54 full schema SQL. Add some SQL here to insert the missing row if it does not exist
+INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS (
+ SELECT * from device_max_stream_id
+);
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql
new file mode 100644
index 0000000000..dfa902d0ba
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 Matrix.org Foundation CIC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Track last seen information for a device in the devices table, rather
+-- than relying on it being in the user_ips table (which we want to be able
+-- to purge old entries from)
+ALTER TABLE devices ADD COLUMN last_seen BIGINT;
+ALTER TABLE devices ADD COLUMN ip TEXT;
+ALTER TABLE devices ADD COLUMN user_agent TEXT;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('devices_last_seen', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql
new file mode 100644
index 0000000000..9f09922c67
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- these tables are never used.
+DROP TABLE IF EXISTS room_names;
+DROP TABLE IF EXISTS topics;
+DROP TABLE IF EXISTS history_visibility;
+DROP TABLE IF EXISTS guest_access;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
new file mode 100644
index 0000000000..81a36a8b1d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_expiry (
+ event_id TEXT PRIMARY KEY,
+ expiry_ts BIGINT NOT NULL
+);
+
+CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
new file mode 100644
index 0000000000..5e29c1da19
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
@@ -0,0 +1,30 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- room_id and topological_ordering are denormalised from the events table in order to
+-- make the index work.
+CREATE TABLE IF NOT EXISTS event_labels (
+ event_id TEXT,
+ label TEXT,
+ room_id TEXT NOT NULL,
+ topological_ordering BIGINT NOT NULL,
+ PRIMARY KEY(event_id, label)
+);
+
+
+-- This index lets a pagination request that filters on a particular label consult the
+-- event_labels table first, which is much quicker than scanning the events table and then
+-- filtering by label, if the label is rarely used relative to the size of the room.
+CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering);
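(Illustrative note, not part of the diff: the kind of label-filtered pagination lookup the (room_id, label, topological_ordering) index is meant to serve, sketched against a throwaway SQLite database; the room, labels and orderings below are hypothetical.)

import sqlite3

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute(
    """
    CREATE TABLE event_labels (
        event_id TEXT,
        label TEXT,
        room_id TEXT NOT NULL,
        topological_ordering BIGINT NOT NULL,
        PRIMARY KEY(event_id, label)
    )
    """
)
cur.execute(
    "CREATE INDEX event_labels_room_id_label_idx"
    " ON event_labels(room_id, label, topological_ordering)"
)
cur.executemany(
    "INSERT INTO event_labels VALUES (?, ?, ?, ?)",
    [
        ("$ev1", "#fun", "!room:example.com", 10),
        ("$ev2", "#work", "!room:example.com", 11),
        ("$ev3", "#fun", "!room:example.com", 12),
    ],
)

# Newest labelled events in a room before a given topological position; this
# can typically be answered from the index without scanning the events table.
rows = cur.execute(
    "SELECT event_id FROM event_labels"
    " WHERE room_id = ? AND label = ? AND topological_ordering < ?"
    " ORDER BY topological_ordering DESC LIMIT 10",
    ("!room:example.com", "#fun", 100),
).fetchall()
print(rows)  # [('$ev3',), ('$ev1',)]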
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql
new file mode 100644
index 0000000000..5f5e0499ae
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('event_store_labels', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql
new file mode 100644
index 0000000000..014cb3b538
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 Matrix.org Foundation CIC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- version is supposed to be part of the room keys index
+CREATE UNIQUE INDEX e2e_room_keys_with_version_idx ON e2e_room_keys(user_id, version, room_id, session_id);
+DROP INDEX IF EXISTS e2e_room_keys_idx;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql
new file mode 100644
index 0000000000..67f8b20297
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- device list needs to know which ones are "real" devices, and which ones are
+-- just used to avoid collisions
+ALTER TABLE devices ADD COLUMN hidden BOOLEAN DEFAULT FALSE;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
new file mode 100644
index 0000000000..e8b1fd35d8
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
@@ -0,0 +1,42 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Change the hidden column from a default value of FALSE to a default value of
+ * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the
+ * string 'FALSE', which is truthy.
+ *
+ * Since sqlite doesn't allow us to just change the default value, we have to
+ * recreate the table, copy the data, fix the rows that have incorrect data, and
+ * replace the old table with the new table.
+ */
+
+CREATE TABLE IF NOT EXISTS devices2 (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ display_name TEXT,
+ last_seen BIGINT,
+ ip TEXT,
+ user_agent TEXT,
+ hidden BOOLEAN DEFAULT 0,
+ CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
+);
+
+INSERT INTO devices2 SELECT * FROM devices;
+
+UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE';
+
+DROP TABLE devices;
+
+ALTER TABLE devices2 RENAME TO devices;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
new file mode 100644
index 0000000000..4f24c1405d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
@@ -0,0 +1,29 @@
+/* Copyright 2019 Werner Sembach
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Groups/communities now get deleted when the last member leaves. This is a one-time cleanup to remove old groups/communities that were already empty before that change was made.
+DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
diff --git a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql
new file mode 100644
index 0000000000..7be31ffebb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
new file mode 100644
index 0000000000..ea95db0ed7
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
new file mode 100644
index 0000000000..49ce35d794
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
@@ -0,0 +1,22 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE redactions ADD COLUMN received_ts BIGINT;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('redactions_received_ts', '{}');
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('redactions_have_censored_ts_idx', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
new file mode 100644
index 0000000000..67471f3ef5
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- There was a bug where we may have updated censored redactions as bytes,
+-- which can (somehow) cause JSON to be inserted hex-encoded. These updates
+-- undo any such hex-encoded JSON.
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('event_fix_redactions_bytes_create_index', '{}');
+
+INSERT into background_updates (update_name, progress_json, depends_on)
+ VALUES ('event_fix_redactions_bytes', '{}', 'event_fix_redactions_bytes_create_index');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
new file mode 100644
index 0000000000..b7550f6f4e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS redactions_have_censored;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
new file mode 100644
index 0000000000..aeb17813d3
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Now that #6232 is a thing, we can remove old rooms from the directory.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('remove_tombstoned_rooms_from_directory', '{}');
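
The background update itself lives in the Python store code; a sketch of the kind of lookup it needs (illustrative only, assuming the existing `rooms.is_public` flag marks directory entries) might be:

    SELECT room_id
      FROM rooms
     WHERE is_public
       AND room_id IN (
           SELECT room_id FROM current_state_events
            WHERE type = 'm.room.tombstone' AND state_key = ''
       );
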
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
new file mode 100644
index 0000000000..7d70dd071e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- store the current etag of backup version
+ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT;
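
For illustration, a change to the keys in a backup version would then be reflected by bumping the etag, along these lines (a sketch only; the user ID and version are placeholder values, and the real update happens in the room keys store):

    UPDATE e2e_room_keys_versions
       SET etag = COALESCE(etag, 0) + 1
     WHERE user_id = '@alice:example.com' AND version = 1;
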
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql
new file mode 100644
index 0000000000..92ab1f5e65
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Adds an index on room_memberships for fetching all forgotten rooms for a user
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('room_membership_forgotten_idx', '{}');
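
The new index exists to serve queries of roughly this shape (illustrative; it assumes the existing `forgotten` flag on `room_memberships`):

    SELECT DISTINCT room_id
      FROM room_memberships
     WHERE user_id = '@alice:example.com' AND forgotten = 1;
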
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
new file mode 100644
index 0000000000..5c5fffcafb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
@@ -0,0 +1,56 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- cross-signing keys
+CREATE TABLE IF NOT EXISTS e2e_cross_signing_keys (
+ user_id TEXT NOT NULL,
+ -- the type of cross-signing key (master, user_signing, or self_signing)
+ keytype TEXT NOT NULL,
+ -- the full key information, as a json-encoded dict
+ keydata TEXT NOT NULL,
+ -- for keeping the keys in order, so that we can fetch the latest one
+ stream_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, stream_id);
+
+-- cross-signing signatures
+CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures (
+ -- user who did the signing
+ user_id TEXT NOT NULL,
+ -- key used to sign
+ key_id TEXT NOT NULL,
+ -- user who was signed
+ target_user_id TEXT NOT NULL,
+ -- device/key that was signed
+ target_device_id TEXT NOT NULL,
+ -- the actual signature
+ signature TEXT NOT NULL
+);
+
+-- replaced by the index created in signing_keys_nonunique_signatures.sql
+-- CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
+
+-- stream of user signature updates
+CREATE TABLE IF NOT EXISTS user_signature_stream (
+ -- uses the same stream ID as device list stream
+ stream_id BIGINT NOT NULL,
+ -- user who did the signing
+ from_user_id TEXT NOT NULL,
+ -- list of users who were signed, as a JSON array
+ user_ids TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream(stream_id);
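
Because keys are ordered by `stream_id`, the most recent key of each type for a user can be fetched with a query such as the following (illustrative user ID; the store code does the equivalent in Python):

    SELECT keytype, keydata
      FROM e2e_cross_signing_keys AS k
     WHERE user_id = '@alice:example.com'
       AND stream_id = (
           SELECT MAX(stream_id) FROM e2e_cross_signing_keys
            WHERE user_id = k.user_id AND keytype = k.keytype
       );
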
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
new file mode 100644
index 0000000000..0aa90ebf0c
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
@@ -0,0 +1,22 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* The cross-signing signatures index should not be a unique index, because a
+ * user may upload multiple signatures for the same target user. The previous
+ * index was unique, so delete it if it's there and create a new non-unique
+ * index. */
+
+DROP INDEX IF EXISTS e2e_cross_signing_signatures_idx;
+
+CREATE INDEX IF NOT EXISTS e2e_cross_signing_signatures2_idx
+    ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
new file mode 100644
index 0000000000..163529c071
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
@@ -0,0 +1,152 @@
+/* Copyright 2018 New Vector Ltd
+ * Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+----- First clean up from previous versions of room stats.
+
+-- First remove old stats stuff
+DROP TABLE IF EXISTS room_stats;
+DROP TABLE IF EXISTS room_state;
+DROP TABLE IF EXISTS room_stats_state;
+DROP TABLE IF EXISTS user_stats;
+DROP TABLE IF EXISTS room_stats_earliest_tokens;
+DROP TABLE IF EXISTS _temp_populate_stats_position;
+DROP TABLE IF EXISTS _temp_populate_stats_rooms;
+DROP TABLE IF EXISTS stats_stream_pos;
+
+-- Unschedule old background updates if they're still scheduled
+DELETE FROM background_updates WHERE update_name IN (
+ 'populate_stats_createtables',
+ 'populate_stats_process_rooms',
+ 'populate_stats_process_users',
+ 'populate_stats_cleanup'
+);
+
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_stats_process_rooms', '{}', '');
+
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_stats_process_users', '{}', 'populate_stats_process_rooms');
+
+----- Create tables for our version of room stats.
+
+-- single-row table to track position of incremental updates
+DROP TABLE IF EXISTS stats_incremental_position;
+CREATE TABLE stats_incremental_position (
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_id BIGINT NOT NULL,
+ CHECK (Lock='X')
+);
+
+-- seed the single row with the current maximum event stream ordering (or 0 if
+-- there are no events), and make sure it is the only one.
+INSERT INTO stats_incremental_position (
+ stream_id
+) SELECT COALESCE(MAX(stream_ordering), 0) from events;
+
+-- represents PRESENT room statistics for a room
+-- only holds absolute fields
+DROP TABLE IF EXISTS room_stats_current;
+CREATE TABLE room_stats_current (
+ room_id TEXT NOT NULL PRIMARY KEY,
+
+ -- These are absolute counts
+ current_state_events INT NOT NULL,
+ joined_members INT NOT NULL,
+ invited_members INT NOT NULL,
+ left_members INT NOT NULL,
+ banned_members INT NOT NULL,
+
+ local_users_in_room INT NOT NULL,
+
+ -- The maximum delta stream position that this row takes into account.
+ completed_delta_stream_id BIGINT NOT NULL
+);
+
+
+-- represents HISTORICAL room statistics for a room
+DROP TABLE IF EXISTS room_stats_historical;
+CREATE TABLE room_stats_historical (
+ room_id TEXT NOT NULL,
+ -- These stats cover the time from (end_ts - bucket_size)...end_ts (in ms).
+ -- Note that end_ts is quantised.
+ end_ts BIGINT NOT NULL,
+ bucket_size BIGINT NOT NULL,
+
+ -- These stats are absolute counts
+ current_state_events BIGINT NOT NULL,
+ joined_members BIGINT NOT NULL,
+ invited_members BIGINT NOT NULL,
+ left_members BIGINT NOT NULL,
+ banned_members BIGINT NOT NULL,
+ local_users_in_room BIGINT NOT NULL,
+
+ -- These stats are per time slice
+ total_events BIGINT NOT NULL,
+ total_event_bytes BIGINT NOT NULL,
+
+ PRIMARY KEY (room_id, end_ts)
+);
+
+-- We use this index to speed up deletion of ancient room stats.
+CREATE INDEX room_stats_historical_end_ts ON room_stats_historical (end_ts);
+
+-- represents PRESENT statistics for a user
+-- only holds absolute fields
+DROP TABLE IF EXISTS user_stats_current;
+CREATE TABLE user_stats_current (
+ user_id TEXT NOT NULL PRIMARY KEY,
+
+ joined_rooms BIGINT NOT NULL,
+
+ -- The maximum delta stream position that this row takes into account.
+ completed_delta_stream_id BIGINT NOT NULL
+);
+
+-- represents HISTORICAL statistics for a user
+DROP TABLE IF EXISTS user_stats_historical;
+CREATE TABLE user_stats_historical (
+ user_id TEXT NOT NULL,
+ end_ts BIGINT NOT NULL,
+ bucket_size BIGINT NOT NULL,
+
+ joined_rooms BIGINT NOT NULL,
+
+ invites_sent BIGINT NOT NULL,
+ rooms_created BIGINT NOT NULL,
+ total_events BIGINT NOT NULL,
+ total_event_bytes BIGINT NOT NULL,
+
+ PRIMARY KEY (user_id, end_ts)
+);
+
+-- We use this index to speed up deletion of ancient user stats.
+CREATE INDEX user_stats_historical_end_ts ON user_stats_historical (end_ts);
+
+
+CREATE TABLE room_stats_state (
+ room_id TEXT NOT NULL,
+ name TEXT,
+ canonical_alias TEXT,
+ join_rules TEXT,
+ history_visibility TEXT,
+ encryption TEXT,
+ avatar TEXT,
+ guest_access TEXT,
+ is_federatable BOOLEAN,
+ topic TEXT
+);
+
+CREATE UNIQUE INDEX room_stats_state_room ON room_stats_state(room_id);
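
Since each historical row covers the window `(end_ts - bucket_size, end_ts]`, a room's recent membership history can be read back with something like the following (illustrative room ID and limit):

    SELECT end_ts - bucket_size AS bucket_start,
           end_ts,
           joined_members,
           total_events
      FROM room_stats_historical
     WHERE room_id = '!room:example.com'
     ORDER BY end_ts DESC
     LIMIT 30;
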
diff --git a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py
new file mode 100644
index 0000000000..1de8b54961
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py
@@ -0,0 +1,52 @@
+import logging
+
+from synapse.storage.engines import PostgresEngine
+
+logger = logging.getLogger(__name__)
+
+
+"""
+This migration updates the user_filters table as follows:
+
+ - drops any (user_id, filter_id) duplicates
+ - makes the columns NON-NULLable
+ - turns the index into a UNIQUE index
+"""
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
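+        # Postgres: DISTINCT ON keeps exactly one row per (user_id, filter_id) pair.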
+ select_clause = """
+ SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json
+ FROM user_filters
+ """
+ else:
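+        # SQLite has no DISTINCT ON; GROUP BY keeps one arbitrary row per (user_id, filter_id).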
+ select_clause = """
+ SELECT * FROM user_filters GROUP BY user_id, filter_id
+ """
+ sql = """
+ DROP TABLE IF EXISTS user_filters_migration;
+ DROP INDEX IF EXISTS user_filters_unique;
+ CREATE TABLE user_filters_migration (
+ user_id TEXT NOT NULL,
+ filter_id BIGINT NOT NULL,
+ filter_json BYTEA NOT NULL
+ );
+ INSERT INTO user_filters_migration (user_id, filter_id, filter_json)
+ %s;
+ CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration
+ (user_id, filter_id);
+ DROP TABLE user_filters;
+ ALTER TABLE user_filters_migration RENAME TO user_filters;
+ """ % (
+ select_clause,
+ )
+
+ if isinstance(database_engine, PostgresEngine):
+ cur.execute(sql)
+ else:
+ cur.executescript(sql)
diff --git a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql
new file mode 100644
index 0000000000..91390c4527
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * a table which records mappings from external auth providers to mxids
+ */
+CREATE TABLE IF NOT EXISTS user_external_ids (
+ auth_provider TEXT NOT NULL,
+ external_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ UNIQUE (auth_provider, external_id)
+);
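
A login via an external provider then resolves to a local user with a lookup of this shape (the provider name and subject are illustrative values, not anything defined by this schema):

    SELECT user_id
      FROM user_external_ids
     WHERE auth_provider = 'oidc' AND external_id = 'subject-1234';
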
diff --git a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql
new file mode 100644
index 0000000000..149f8be8b6
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 Matrix.org Foundation CIC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this was apparently forgotten when the table was created back in delta 53.
+CREATE INDEX users_in_public_rooms_r_idx ON users_in_public_rooms(room_id);
diff --git a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql b/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql
new file mode 100644
index 0000000000..aec06c8261
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add background update to go and delete current state events for rooms the
+-- server is no longer in.
+--
+-- this relies on the 'membership' column of current_state_events, so make sure
+-- that's populated first!
+INSERT into background_updates (update_name, progress_json, depends_on)
+ VALUES ('delete_old_current_state_events', '{}', 'current_state_events_membership');
diff --git a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql b/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql
new file mode 100644
index 0000000000..c3b6de2099
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Records whether the server thinks that the remote users cached device lists
+-- may be out of date (e.g. if we have received a to device message from a
+-- device we don't know about).
+CREATE TABLE IF NOT EXISTS device_lists_remote_resync (
+ user_id TEXT NOT NULL,
+ added_ts BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX device_lists_remote_resync_idx ON device_lists_remote_resync (user_id);
+CREATE INDEX device_lists_remote_resync_ts_idx ON device_lists_remote_resync (added_ts);
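
Flagging a remote user's cached device list as stale is then an insert keyed on `user_id`, roughly as below (a sketch only, assuming the `ON CONFLICT` syntax supported by Postgres and recent SQLite; the store code performs the equivalent upsert):

    INSERT INTO device_lists_remote_resync (user_id, added_ts)
    VALUES ('@bob:remote.example', 1580000000000)
    ON CONFLICT (user_id) DO NOTHING;
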
diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
new file mode 100644
index 0000000000..63b5acdcf7
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# We create a new table called `local_current_membership` that stores the latest
+# membership state of local users in rooms, which helps track leaves/bans/etc
+# even if the server has left the room (and so has deleted the room from
+# `current_state_events`). This will also include outstanding invites for local
+# users for rooms the server isn't in.
+#
+# If the server isn't and hasn't been in the room then it will only include
+# outstanding invites, and not e.g. pre-emptive bans of local users.
+#
+# If the server later rejoins a room `local_current_membership` can simply be
+# replaced with the new current state of the room (which results in the
+# equivalent behaviour as if the server had remained in the room).
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+    # We need to do the insert in the `run_upgrade` section, as we don't have
+    # access to `config` in `run_create`.
+
+    # This upgrade may take a bit of time for large servers (e.g. one minute for
+    # matrix.org) but means we avoid a lot of the bookkeeping that would be
+    # required to do it as a background update.
+
+    # We check whether `current_state_events.membership` is up to date by
+    # checking whether the relevant background update has finished. If it has,
+    # we can avoid doing a join against `room_memberships`, which speeds
+    # things up.
+ cur.execute(
+ """SELECT 1 FROM background_updates
+ WHERE update_name = 'current_state_events_membership'
+ """
+ )
+ current_state_membership_up_to_date = not bool(cur.fetchone())
+
+ # Cheekily drop and recreate indices, as that is faster.
+ cur.execute("DROP INDEX local_current_membership_idx")
+ cur.execute("DROP INDEX local_current_membership_room_idx")
+
+ if current_state_membership_up_to_date:
+ sql = """
+ INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+ SELECT c.room_id, state_key AS user_id, event_id, c.membership
+ FROM current_state_events AS c
+ WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key LIKE ?
+ """
+ else:
+ # We can't rely on the membership column, so we need to join against
+ # `room_memberships`.
+ sql = """
+ INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+ SELECT c.room_id, state_key AS user_id, event_id, r.membership
+ FROM current_state_events AS c
+ INNER JOIN room_memberships AS r USING (event_id)
+ WHERE type = 'm.room.member' AND state_key LIKE ?
+ """
+ sql = database_engine.convert_param_style(sql)
+ cur.execute(sql, ("%:" + config.server_name,))
+
+ cur.execute(
+ "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+ )
+ cur.execute(
+ "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+ )
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ cur.execute(
+ """
+ CREATE TABLE local_current_membership (
+ room_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ membership TEXT NOT NULL
+ )"""
+ )
+
+ cur.execute(
+ "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+ )
+ cur.execute(
+ "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+ )
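
Once populated, the table answers questions that `current_state_events` no longer can after the server leaves a room, for example (illustrative user ID):

    SELECT room_id, membership
      FROM local_current_membership
     WHERE user_id = '@alice:example.com'
       AND membership IN ('ban', 'leave', 'invite');
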
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql
new file mode 100644
index 0000000000..352a66f5b0
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- We want to start storing the room version independently of
+-- `current_state_events` so that we can delete stale entries from it without
+-- losing the information.
+ALTER TABLE rooms ADD COLUMN room_version TEXT;
+
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('add_rooms_room_version_column', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres
new file mode 100644
index 0000000000..c601cff6de
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres
@@ -0,0 +1,35 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- when we first added the room_version column, it was populated via a background
+-- update. We now need it to be populated before synapse starts, so we populate
+-- any remaining rows with a NULL room version now. For servers which have completed
+-- the background update, this will be pretty quick.
+
+-- the following query will set room_version to NULL if no create event is found for
+-- the room in current_state_events, and will set it to '1' if a create event with no
+-- room_version is found.
+
+UPDATE rooms SET room_version=(
+ SELECT COALESCE(json::json->'content'->>'room_version','1')
+ FROM current_state_events cse INNER JOIN event_json ej USING (event_id)
+ WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key=''
+) WHERE rooms.room_version IS NULL;
+
+-- we still allow the background update to complete: it has the useful side-effect of
+-- populating `rooms` with any missing rooms (based on the current_state_events table).
+
+-- see also rooms_version_column_2.sql.sqlite which has a copy of the above query, using
+-- sqlite syntax for the json extraction.
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite
new file mode 100644
index 0000000000..335c6f2074
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- see rooms_version_column_2.sql.postgres for details of what's going on here.
+
+UPDATE rooms SET room_version=(
+ SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1')
+ FROM current_state_events cse INNER JOIN event_json ej USING (event_id)
+ WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key=''
+) WHERE rooms.room_version IS NULL;
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
new file mode 100644
index 0000000000..92aaadde0d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
@@ -0,0 +1,39 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- When we first added the room_version column to the rooms table, it was populated from
+-- the current_state_events table. However, there was an issue causing a background
+-- update to clean up the current_state_events table for rooms where the server is no
+-- longer participating, before that column could be populated. Therefore, some rooms had
+-- a NULL room_version.
+
+-- The rooms_version_column_2.sql.* delta files were introduced to make the populating
+-- synchronous instead of running it in a background update, which fixed this issue.
+-- However, all of the instances of Synapse installed or updated in the meantime got
+-- their rooms table corrupted with NULL room_versions.
+
+-- This query fishes out the room versions from the create event using the state_events
+-- table instead of the current_state_events one, as the former still has all of the
+-- create events.
+
+UPDATE rooms SET room_version=(
+ SELECT COALESCE(json::json->'content'->>'room_version','1')
+ FROM state_events se INNER JOIN event_json ej USING (event_id)
+ WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
+ LIMIT 1
+) WHERE rooms.room_version IS NULL;
+
+-- see also rooms_version_column_3.sql.sqlite which has a copy of the above query, using
+-- sqlite syntax for the json extraction.
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
new file mode 100644
index 0000000000..e19dab97cb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
@@ -0,0 +1,23 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- see rooms_version_column_3.sql.postgres for details of what's going on here.
+
+UPDATE rooms SET room_version=(
+ SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1')
+ FROM state_events se INNER JOIN event_json ej USING (event_id)
+ WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
+ LIMIT 1
+) WHERE rooms.room_version IS NULL;
diff --git a/synapse/storage/schema/full_schemas/16/application_services.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql
index 883fcd10b2..883fcd10b2 100644
--- a/synapse/storage/schema/full_schemas/16/application_services.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql
diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql
index 10ce2aa7a0..10ce2aa7a0 100644
--- a/synapse/storage/schema/full_schemas/16/event_edges.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql
diff --git a/synapse/storage/schema/full_schemas/16/event_signatures.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql
index 95826da431..95826da431 100644
--- a/synapse/storage/schema/full_schemas/16/event_signatures.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql
diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql
index a1a2aa8e5b..a1a2aa8e5b 100644
--- a/synapse/storage/schema/full_schemas/16/im.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql
diff --git a/synapse/storage/schema/full_schemas/16/keys.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql
index 11cdffdbb3..11cdffdbb3 100644
--- a/synapse/storage/schema/full_schemas/16/keys.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql
diff --git a/synapse/storage/schema/full_schemas/16/media_repository.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql
index 8f3759bb2a..8f3759bb2a 100644
--- a/synapse/storage/schema/full_schemas/16/media_repository.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql
diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql
index 01d2d8f833..01d2d8f833 100644
--- a/synapse/storage/schema/full_schemas/16/presence.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql
diff --git a/synapse/storage/schema/full_schemas/16/profiles.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql
index c04f4747d9..c04f4747d9 100644
--- a/synapse/storage/schema/full_schemas/16/profiles.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql
diff --git a/synapse/storage/schema/full_schemas/16/push.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql
index e44465cf45..e44465cf45 100644
--- a/synapse/storage/schema/full_schemas/16/push.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql
diff --git a/synapse/storage/schema/full_schemas/16/redactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql
index 318f0d9aa5..318f0d9aa5 100644
--- a/synapse/storage/schema/full_schemas/16/redactions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql
diff --git a/synapse/storage/schema/full_schemas/16/room_aliases.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql
index d47da3b12f..d47da3b12f 100644
--- a/synapse/storage/schema/full_schemas/16/room_aliases.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql
diff --git a/synapse/storage/schema/full_schemas/16/state.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql
index 96391a8f0e..96391a8f0e 100644
--- a/synapse/storage/schema/full_schemas/16/state.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql
diff --git a/synapse/storage/schema/full_schemas/16/transactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql
index 17e67bedac..17e67bedac 100644
--- a/synapse/storage/schema/full_schemas/16/transactions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql
diff --git a/synapse/storage/schema/full_schemas/16/users.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql
index f013aa8b18..f013aa8b18 100644
--- a/synapse/storage/schema/full_schemas/16/users.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
index 01a2b0e024..20c5af2eb7 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
@@ -70,15 +70,6 @@ CREATE TABLE appservice_stream_position (
);
-
-CREATE TABLE background_updates (
- update_name text NOT NULL,
- progress_json text NOT NULL,
- depends_on text
-);
-
-
-
CREATE TABLE blocked_rooms (
room_id text NOT NULL,
user_id text NOT NULL
@@ -993,40 +984,6 @@ CREATE TABLE state_events (
-CREATE TABLE state_group_edges (
- state_group bigint NOT NULL,
- prev_state_group bigint NOT NULL
-);
-
-
-
-CREATE SEQUENCE state_group_id_seq
- START WITH 1
- INCREMENT BY 1
- NO MINVALUE
- NO MAXVALUE
- CACHE 1;
-
-
-
-CREATE TABLE state_groups (
- id bigint NOT NULL,
- room_id text NOT NULL,
- event_id text NOT NULL
-);
-
-
-
-CREATE TABLE state_groups_state (
- state_group bigint NOT NULL,
- room_id text NOT NULL,
- type text NOT NULL,
- state_key text NOT NULL,
- event_id text NOT NULL
-);
-
-
-
CREATE TABLE stats_stream_pos (
lock character(1) DEFAULT 'X'::bpchar NOT NULL,
stream_id bigint,
@@ -1211,11 +1168,6 @@ ALTER TABLE ONLY appservice_stream_position
-ALTER TABLE ONLY background_updates
- ADD CONSTRAINT background_updates_uniqueness UNIQUE (update_name);
-
-
-
ALTER TABLE ONLY current_state_events
ADD CONSTRAINT current_state_events_event_id_key UNIQUE (event_id);
@@ -1505,12 +1457,6 @@ ALTER TABLE ONLY state_events
ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id);
-
-ALTER TABLE ONLY state_groups
- ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id);
-
-
-
ALTER TABLE ONLY stats_stream_pos
ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock);
@@ -1955,18 +1901,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts);
-CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group);
-
-
-
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group);
-
-
-
-CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key);
-
-
-
CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering);
@@ -2060,6 +1994,3 @@ CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_room
CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms USING btree (user_id, other_user_id, room_id);
-
-
-
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
index f1a71627f0..e28ec3fa45 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
@@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id);
CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) );
CREATE INDEX room_depth_room ON room_depth(room_id);
-CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
-CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL );
CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) );
CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) );
CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) );
@@ -67,7 +65,6 @@ CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id );
CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id );
CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, validated_at BIGINT NOT NULL, added_at BIGINT NOT NULL, CONSTRAINT medium_address UNIQUE (medium, address) );
CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
-CREATE TABLE background_updates( update_name TEXT NOT NULL, progress_json TEXT NOT NULL, depends_on TEXT, CONSTRAINT background_updates_uniqueness UNIQUE (update_name) );
CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value )
/* event_search(event_id,room_id,sender,"key",value) */;
CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value');
@@ -121,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL );
CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT);
CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id );
CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id );
-CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL );
-CREATE INDEX state_group_edges_idx ON state_group_edges(state_group);
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group);
CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering );
CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering );
@@ -257,6 +251,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen);
CREATE INDEX users_creation_ts ON users (creation_ts);
CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group);
CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id);
-CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key);
CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id);
CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip);
diff --git a/synapse/storage/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
index c265fd20e2..91d21b2921 100644
--- a/synapse/storage/schema/full_schemas/54/stream_positions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
@@ -5,3 +5,4 @@ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coales
INSERT INTO user_directory_stream_pos (stream_id) VALUES (0);
INSERT INTO stats_stream_pos (stream_id) VALUES (0);
INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);
+-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
new file mode 100644
index 0000000000..c00f287190
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,21 @@
+# Synapse Database Schemas
+
+These schemas are used as a basis to create brand new Synapse databases, on both
+SQLite3 and Postgres.
+
+## Building full schema dumps
+
+If you want to recreate these schemas, they need to be made from a database that
+has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce new
+`full.sql.postgres` and `full.sql.sqlite` files.
+
+Ensure Postgres is installed and that your user is able to run commands such as
+`createdb`, then call
+
+ ./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
+
+There are currently two folders with full-schema snapshots. `16` is a snapshot
+from 2015, for historical reference. The other contains the most recent full
+schema snapshot.
diff --git a/synapse/storage/search.py b/synapse/storage/data_stores/main/search.py
index ff49eaae02..47ebb8a214 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -24,35 +24,36 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from .background_updates import BackgroundUpdateStore
-
logger = logging.getLogger(__name__)
SearchEntry = namedtuple(
- 'SearchEntry',
- ['key', 'value', 'event_id', 'room_id', 'stream_ordering', 'origin_server_ts'],
+ "SearchEntry",
+ ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"],
)
-class SearchStore(BackgroundUpdateStore):
+class SearchBackgroundUpdateStore(SQLBaseStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
- def __init__(self, db_conn, hs):
- super(SearchStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
if not hs.config.enable_search:
return
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
)
@@ -61,9 +62,11 @@ class SearchStore(BackgroundUpdateStore):
# a GIN index. However, it's possible that some people might still have
# the background update queued, so we register a handler to clear the
# background update.
- self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
+ self.db.updates.register_noop_background_update(
+ self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME
+ )
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)
@@ -93,7 +96,7 @@ class SearchStore(BackgroundUpdateStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
return 0
@@ -153,20 +156,20 @@ class SearchStore(BackgroundUpdateStore):
"rows_inserted": rows_inserted + len(event_search_rows),
}
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, self.EVENT_SEARCH_UPDATE_NAME, progress
)
return len(event_search_rows)
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
if not result:
- yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
+ yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _background_reindex_gin_search(self, progress, batch_size):
@@ -196,7 +199,7 @@ class SearchStore(BackgroundUpdateStore):
" ON event_search USING GIN (vector)"
)
except psycopg2.ProgrammingError as e:
- logger.warn(
+ logger.warning(
"Ignoring error %r when trying to switch from GIST to GIN", e
)
@@ -206,17 +209,19 @@ class SearchStore(BackgroundUpdateStore):
conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine):
- yield self.runWithConnection(create_index)
+ yield self.db.runWithConnection(create_index)
- yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
- defer.returnValue(1)
+ yield self.db.updates._end_background_update(
+ self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
+ )
+ return 1
@defer.inlineCallbacks
def _background_reindex_search_order(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- have_added_index = progress['have_added_indexes']
+ have_added_index = progress["have_added_indexes"]
if not have_added_index:
@@ -237,14 +242,14 @@ class SearchStore(BackgroundUpdateStore):
)
conn.set_session(autocommit=False)
- yield self.runWithConnection(create_index)
+ yield self.db.runWithConnection(create_index)
pg = dict(progress)
pg["have_added_indexes"] = True
- yield self.runInteraction(
+ yield self.db.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
- self._background_update_progress_txn,
+ self.db.updates._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
pg,
)
@@ -274,43 +279,22 @@ class SearchStore(BackgroundUpdateStore):
"have_added_indexes": True,
}
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
)
return len(rows), True
- num_rows, finished = yield self.runInteraction(
+ num_rows, finished = yield self.db.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
)
if not finished:
- yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
-
- defer.returnValue(num_rows)
-
- def store_event_search_txn(self, txn, event, key, value):
- """Add event to the search table
+ yield self.db.updates._end_background_update(
+ self.EVENT_SEARCH_ORDER_UPDATE_NAME
+ )
- Args:
- txn (cursor):
- event (EventBase):
- key (str):
- value (str):
- """
- self.store_search_entries_txn(
- txn,
- (
- SearchEntry(
- key=key,
- value=value,
- event_id=event.event_id,
- room_id=event.room_id,
- stream_ordering=event.internal_metadata.stream_ordering,
- origin_server_ts=event.origin_server_ts,
- ),
- ),
- )
+ return num_rows
def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table
@@ -341,29 +325,7 @@ class SearchStore(BackgroundUpdateStore):
for entry in entries
)
- # inserts to a GIN index are normally batched up into a pending
- # list, and then all committed together once the list gets to a
- # certain size. The trouble with that is that postgres (pre-9.5)
- # uses work_mem to determine the length of the list, and work_mem
- # is typically very large.
- #
- # We therefore reduce work_mem while we do the insert.
- #
- # (postgres 9.5 uses the separate gin_pending_list_limit setting,
- # so doesn't suffer the same problem, but changing work_mem will
- # be harmless)
- #
- # Note that we don't need to worry about restoring it on
- # exception, because exceptions will cause the transaction to be
- # rolled back, including the effects of the SET command.
- #
- # Also: we use SET rather than SET LOCAL because there's lots of
- # other stuff going on in this transaction, which want to have the
- # normal work_mem setting.
-
- txn.execute("SET work_mem='256kB'")
txn.executemany(sql, args)
- txn.execute("RESET work_mem")
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
@@ -380,6 +342,34 @@ class SearchStore(BackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
+
+class SearchStore(SearchBackgroundUpdateStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(SearchStore, self).__init__(database, db_conn, hs)
+
+ def store_event_search_txn(self, txn, event, key, value):
+ """Add event to the search table
+
+ Args:
+ txn (cursor):
+ event (EventBase):
+ key (str):
+ value (str):
+ """
+ self.store_search_entries_txn(
+ txn,
+ (
+ SearchEntry(
+ key=key,
+ value=value,
+ event_id=event.event_id,
+ room_id=event.room_id,
+ stream_ordering=event.internal_metadata.stream_ordering,
+ origin_server_ts=event.origin_server_ts,
+ ),
+ ),
+ )
+
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
@@ -395,15 +385,17 @@ class SearchStore(BackgroundUpdateStore):
"""
clauses = []
- search_query = search_query = _parse_query(self.database_engine, search_term)
+ search_query = _parse_query(self.database_engine, search_term)
args = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
- clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
- args.extend(room_ids)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
+ clauses = [clause]
local_clauses = []
for key in keys:
@@ -456,11 +448,18 @@ class SearchStore(BackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args)
+ results = yield self.db.execute(
+ "search_msgs", self.db.cursor_to_dict, sql, *args
+ )
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self.get_events_as_list([r["event_id"] for r in results])
+ # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+ # search results (which is a data leak)
+ events = yield self.get_events_as_list(
+ [r["event_id"] for r in results],
+ redact_behaviour=EventRedactBehaviour.BLOCK,
+ )
event_map = {ev.event_id: ev for ev in events}
@@ -470,23 +469,21 @@ class SearchStore(BackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = yield self._execute(
- "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+ count_results = yield self.db.execute(
+ "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
- defer.returnValue(
- {
- "results": [
- {"event": event_map[r["event_id"]], "rank": r["rank"]}
- for r in results
- if r["event_id"] in event_map
- ],
- "highlights": highlights,
- "count": count,
- }
- )
+ return {
+ "results": [
+ {"event": event_map[r["event_id"]], "rank": r["rank"]}
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ "count": count,
+ }
@defer.inlineCallbacks
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
@@ -504,15 +501,17 @@ class SearchStore(BackgroundUpdateStore):
"""
clauses = []
- search_query = search_query = _parse_query(self.database_engine, search_term)
+ search_query = _parse_query(self.database_engine, search_term)
args = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
- clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
- args.extend(room_ids)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
+ clauses = [clause]
local_clauses = []
for key in keys:
@@ -601,11 +600,18 @@ class SearchStore(BackgroundUpdateStore):
args.append(limit)
- results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args)
+ results = yield self.db.execute(
+ "search_rooms", self.db.cursor_to_dict, sql, *args
+ )
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self.get_events_as_list([r["event_id"] for r in results])
+ # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+ # search results (which is a data leak)
+ events = yield self.get_events_as_list(
+ [r["event_id"] for r in results],
+ redact_behaviour=EventRedactBehaviour.BLOCK,
+ )
event_map = {ev.event_id: ev for ev in events}
@@ -615,28 +621,26 @@ class SearchStore(BackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = yield self._execute(
- "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+ count_results = yield self.db.execute(
+ "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
- defer.returnValue(
- {
- "results": [
- {
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- "pagination_token": "%s,%s"
- % (r["origin_server_ts"], r["stream_ordering"]),
- }
- for r in results
- if r["event_id"] in event_map
- ],
- "highlights": highlights,
- "count": count,
- }
- )
+ return {
+ "results": [
+ {
+ "event": event_map[r["event_id"]],
+ "rank": r["rank"],
+ "pagination_token": "%s,%s"
+ % (r["origin_server_ts"], r["stream_ordering"]),
+ }
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ "count": count,
+ }
def _find_highlights_in_postgres(self, search_query, events):
"""Given a list of events and a search term, return a list of words
@@ -689,7 +693,7 @@ class SearchStore(BackgroundUpdateStore):
)
)
txn.execute(query, (value, search_query))
- headline, = txn.fetchall()[0]
+ (headline,) = txn.fetchall()[0]
# Now we need to pick the possible highlights out of the headline
# result.
@@ -703,7 +707,7 @@ class SearchStore(BackgroundUpdateStore):
return highlight_words
- return self.runInteraction("_find_highlights", f)
+ return self.db.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
diff --git a/synapse/storage/signatures.py b/synapse/storage/data_stores/main/signatures.py
index 6bd81e84ad..563216b63c 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -20,10 +20,9 @@ from unpaddedbase64 import encode_base64
from twisted.internet import defer
from synapse.crypto.event_signing import compute_event_reference_hash
+from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
-from ._base import SQLBaseStore
-
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
@@ -49,7 +48,7 @@ class SignatureWorkerStore(SQLBaseStore):
for event_id in event_ids
}
- return self.runInteraction("get_event_reference_hashes", f)
+ return self.db.runInteraction("get_event_reference_hashes", f)
@defer.inlineCallbacks
def add_event_hashes(self, event_ids):
@@ -59,7 +58,7 @@ class SignatureWorkerStore(SQLBaseStore):
for e_id, h in hashes.items()
}
- defer.returnValue(list(hashes.items()))
+ return list(hashes.items())
def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.
@@ -99,4 +98,4 @@ class SignatureStore(SignatureWorkerStore):
}
)
- self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
+ self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
new file mode 100644
index 0000000000..3a3b9a8e72
--- /dev/null
+++ b/synapse/storage/data_stores/main/state.py
@@ -0,0 +1,505 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections.abc
+import logging
+from collections import namedtuple
+from typing import Iterable, Tuple
+
+from six import iteritems
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.database import Database
+from synapse.storage.state import StateFilter
+from synapse.util.caches import intern_string
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.stringutils import to_ascii
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class _GetStateGroupDelta(
+ namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
+ """Return type of get_state_group_delta that implements __len__, which lets
+    us use the iterable flag when caching
+ """
+
+ __slots__ = []
+
+ def __len__(self):
+ return len(self.delta_ids) if self.delta_ids else 0
+
+
+# this inherits from EventsWorkerStore because it calls self.get_events
+class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
+ """The parts of StateGroupStore that can be called from workers.
+ """
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
+
+ async def get_room_version(self, room_id: str) -> RoomVersion:
+ """Get the room_version of a given room
+
+ Raises:
+ NotFoundError: if the room is unknown
+
+ UnsupportedRoomVersionError: if the room uses an unknown room version.
+ Typically this happens if support for the room's version has been
+ removed from Synapse.
+ """
+ room_version_id = await self.get_room_version_id(room_id)
+ v = KNOWN_ROOM_VERSIONS.get(room_version_id)
+
+ if not v:
+ raise UnsupportedRoomVersionError(
+ "Room %s uses a room version %s which is no longer supported"
+ % (room_id, room_version_id)
+ )
+
+ return v
+
+ @cached(max_entries=10000)
+ async def get_room_version_id(self, room_id: str) -> str:
+ """Get the room_version of a given room
+
+ Raises:
+ NotFoundError: if the room is unknown
+ """
+
+ # First we try looking up room version from the database, but for old
+ # rooms we might not have added the room version to it yet so we fall
+ # back to previous behaviour and look in current state events.
+
+ # We really should have an entry in the rooms table for every room we
+ # care about, but let's be a bit paranoid (at least while the background
+ # update is happening) to avoid breaking existing rooms.
+ version = await self.db.simple_select_one_onecol(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcol="room_version",
+ desc="get_room_version",
+ allow_none=True,
+ )
+
+ if version is not None:
+ return version
+
+ # Retrieve the room's create event
+ create_event = await self.get_create_event_for_room(room_id)
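+        # rooms created before room versioning have no "room_version" key in
+        # their create event, so fall back to version "1"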
+ return create_event.content.get("room_version", "1")
+
+ @defer.inlineCallbacks
+ def get_room_predecessor(self, room_id):
+ """Get the predecessor of an upgraded room if it exists.
+ Otherwise return None.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[dict|None]: A dictionary containing the structure of the predecessor
+            field from the room's create event. The structure is determined by other
+            servers, but it is expected to contain:
+ * room_id (str): The room ID of the predecessor room
+ * event_id (str): The ID of the tombstone event in the predecessor room
+
+ None if a predecessor key is not found, or is not a dictionary.
+
+ Raises:
+ NotFoundError if the given room is unknown
+ """
+ # Retrieve the room's create event
+ create_event = yield self.get_create_event_for_room(room_id)
+
+ # Retrieve the predecessor key of the create event
+ predecessor = create_event.content.get("predecessor", None)
+
+ # Ensure the key is a dictionary
+ if not isinstance(predecessor, collections.abc.Mapping):
+ return None
+
+ return predecessor
+
+ @defer.inlineCallbacks
+ def get_create_event_for_room(self, room_id):
+ """Get the create state event for a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[EventBase]: The room creation event.
+
+ Raises:
+ NotFoundError if the room is unknown
+ """
+ state_ids = yield self.get_current_state_ids(room_id)
+ create_id = state_ids.get((EventTypes.Create, ""))
+
+ # If we can't find the create event, assume we've hit a dead end
+ if not create_id:
+ raise NotFoundError("Unknown room %s" % (room_id,))
+
+ # Retrieve the room's create event and return
+ create_event = yield self.get_event(create_id)
+ return create_event
+
+ @cached(max_entries=100000, iterable=True)
+ def get_current_state_ids(self, room_id):
+ """Get the current state event ids for a room based on the
+ current_state_events table.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ deferred: dict of (type, state_key) -> event_id
+ """
+
+ def _get_current_state_ids_txn(txn):
+ txn.execute(
+ """SELECT type, state_key, event_id FROM current_state_events
+ WHERE room_id = ?
+ """,
+ (room_id,),
+ )
+
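+            # intern the type/state_key strings, since the same values recur
+            # across many rooms; interning keeps the cache's memory footprint down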
+ return {
+ (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
+ }
+
+ return self.db.runInteraction(
+ "get_current_state_ids", _get_current_state_ids_txn
+ )
+
+ # FIXME: how should this be cached?
+ def get_filtered_current_state_ids(
+ self, room_id: str, state_filter: StateFilter = StateFilter.all()
+ ):
+ """Get the current state event of a given type for a room based on the
+ current_state_events table. This may not be as up-to-date as the result
+ of doing a fresh state resolution as per state_handler.get_current_state
+
+ Args:
+ room_id
+ state_filter: The state filter used to fetch state
+ from the database.
+
+ Returns:
+ defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
+ """
+
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ if not where_clause:
+ # We delegate to the cached version
+ return self.get_current_state_ids(room_id)
+
+ def _get_filtered_current_state_ids_txn(txn):
+ results = {}
+ sql = """
+ SELECT type, state_key, event_id FROM current_state_events
+ WHERE room_id = ?
+ """
+
+ if where_clause:
+ sql += " AND (%s)" % (where_clause,)
+
+ args = [room_id]
+ args.extend(where_args)
+ txn.execute(sql, args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (intern_string(typ), intern_string(state_key))
+ results[key] = event_id
+
+ return results
+
+ return self.db.runInteraction(
+ "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_canonical_alias_for_room(self, room_id):
+ """Get canonical alias for room, if any
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[str|None]: The canonical alias, if any
+ """
+
+ state = yield self.get_filtered_current_state_ids(
+ room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+ )
+
+ event_id = state.get((EventTypes.CanonicalAlias, ""))
+ if not event_id:
+ return
+
+ event = yield self.get_event(event_id, allow_none=True)
+ if not event:
+ return
+
+ return event.content.get("canonical_alias")
+
+ @cached(max_entries=50000)
+ def _get_state_group_for_event(self, event_id):
+ return self.db.simple_select_one_onecol(
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ retcol="state_group",
+ allow_none=True,
+ desc="_get_state_group_for_event",
+ )
+
+ @cachedList(
+ cached_method_name="_get_state_group_for_event",
+ list_name="event_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
+ def _get_state_group_for_events(self, event_ids):
+ """Returns mapping event_id -> state_group
+ """
+ rows = yield self.db.simple_select_many_batch(
+ table="event_to_state_groups",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=("event_id", "state_group"),
+ desc="_get_state_group_for_events",
+ )
+
+ return {row["event_id"]: row["state_group"] for row in rows}
+
+ @defer.inlineCallbacks
+ def get_referenced_state_groups(self, state_groups):
+ """Check if the state groups are referenced by events.
+
+ Args:
+ state_groups (Iterable[int])
+
+ Returns:
+ Deferred[set[int]]: The subset of state groups that are
+ referenced.
+ """
+
+ rows = yield self.db.simple_select_many_batch(
+ table="event_to_state_groups",
+ column="state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("DISTINCT state_group",),
+ desc="get_referenced_state_groups",
+ )
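+        # ("DISTINCT state_group" as the retcol makes the underlying SELECT
+        # deduplicate the state groups before they reach us)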
+
+ return {row["state_group"] for row in rows}
+
+
+class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
+
+ CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+ EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
+ DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
+ self.server_name = hs.hostname
+
+ self.db.updates.register_background_index_update(
+ self.CURRENT_STATE_INDEX_UPDATE_NAME,
+ index_name="current_state_events_member_index",
+ table="current_state_events",
+ columns=["state_key"],
+ where_clause="type='m.room.member'",
+ )
+ self.db.updates.register_background_index_update(
+ self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
+ index_name="event_to_state_groups_sg_index",
+ table="event_to_state_groups",
+ columns=["state_group"],
+ )
+ self.db.updates.register_background_update_handler(
+ self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
+ )
+
+ async def _background_remove_left_rooms(self, progress, batch_size):
+ """Background update to delete rows from `current_state_events` and
+ `event_forward_extremities` tables of rooms that the server is no
+ longer joined to.
+ """
+
+ last_room_id = progress.get("last_room_id", "")
+
+ def _background_remove_left_rooms_txn(txn):
+ sql = """
+ SELECT DISTINCT room_id FROM current_state_events
+ WHERE room_id > ? ORDER BY room_id LIMIT ?
+ """
+
+ txn.execute(sql, (last_room_id, batch_size))
+ room_ids = [row[0] for row in txn]
+ if not room_ids:
+ return True, set()
+
+ sql = """
+ SELECT room_id
+ FROM current_state_events
+ WHERE
+ room_id > ? AND room_id <= ?
+ AND type = 'm.room.member'
+ AND membership = 'join'
+ AND state_key LIKE ?
+ GROUP BY room_id
+ """
+
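+            # the LIKE pattern matches state_keys (user IDs) ending in
+            # ":<our server name>", i.e. users local to this homeserver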
+ txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
+
+ joined_room_ids = {row[0] for row in txn}
+
+ left_rooms = set(room_ids) - joined_room_ids
+
+ logger.info("Deleting current state left rooms: %r", left_rooms)
+
+ # First we get all users that we still think were joined to the
+ # room. This is so that we can mark those device lists as
+ # potentially stale, since there may have been a period where the
+ # server didn't share a room with the remote user and therefore may
+ # have missed any device updates.
+ rows = self.db.simple_select_many_txn(
+ txn,
+ table="current_state_events",
+ column="room_id",
+ iterable=left_rooms,
+ keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
+ retcols=("state_key",),
+ )
+
+ potentially_left_users = {row["state_key"] for row in rows}
+
+            # Now let's actually delete the rooms from the DB.
+ self.db.simple_delete_many_txn(
+ txn,
+ table="current_state_events",
+ column="room_id",
+ iterable=left_rooms,
+ keyvalues={},
+ )
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="event_forward_extremities",
+ column="room_id",
+ iterable=left_rooms,
+ keyvalues={},
+ )
+
+ self.db.updates._background_update_progress_txn(
+ txn,
+ self.DELETE_CURRENT_STATE_UPDATE_NAME,
+ {"last_room_id": room_ids[-1]},
+ )
+
+ return False, potentially_left_users
+
+ finished, potentially_left_users = await self.db.runInteraction(
+ "_background_remove_left_rooms", _background_remove_left_rooms_txn
+ )
+
+ if finished:
+ await self.db.updates._end_background_update(
+ self.DELETE_CURRENT_STATE_UPDATE_NAME
+ )
+
+ # Now go and check if we still share a room with the remote users in
+ # the deleted rooms. If not mark their device lists as stale.
+ joined_users = await self.get_users_server_still_shares_room_with(
+ potentially_left_users
+ )
+
+ for user_id in potentially_left_users - joined_users:
+ await self.mark_remote_user_device_list_as_unsubscribed(user_id)
+
+ return batch_size
+
+
+class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
+ """ Keeps track of the state at a given event.
+
+    This is done by the concept of `state groups`. Every event is assigned
+ a state group (identified by an arbitrary string), which references a
+ collection of state events. The current state of an event is then the
+ collection of state events referenced by the event's state group.
+
+ Hence, every change in the current state causes a new state group to be
+    generated. However, if no change happens (e.g. if we get a message event
+    with only one parent), it inherits the state group from its parent.
+
+ There are three tables:
+    * `state_groups`: Stores the group name, the first event within the group,
+      and the room id.
+ * `event_to_state_groups`: Maps events to state groups.
+ * `state_groups_state`: Maps state group to state events.
+ """
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateStore, self).__init__(database, db_conn, hs)
+
+ def _store_event_state_mappings_txn(
+ self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
+ ):
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if event.internal_metadata.is_outlier():
+ continue
+
+ # if the event was rejected, just give it the same state as its
+ # predecessor.
+ if context.rejected:
+ state_groups[event.event_id] = context.state_group_before_event
+ continue
+
+ state_groups[event.event_id] = context.state_group
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="event_to_state_groups",
+ values=[
+ {"state_group": state_group_id, "event_id": event_id}
+ for event_id, state_group_id in iteritems(state_groups)
+ ],
+ )
+
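+        # prime the _get_state_group_for_event cache once the transaction
+        # completes, so the mappings we just wrote can be served without
+        # another database hit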
+ for event_id, state_group_id in iteritems(state_groups):
+ txn.call_after(
+ self._get_state_group_for_event.prefill, (event_id,), state_group_id
+ )
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py
index 5fdb442104..725e12507f 100644
--- a/synapse/storage/state_deltas.py
+++ b/synapse/storage/data_stores/main/state_deltas.py
@@ -15,13 +15,15 @@
import logging
+from twisted.internet import defer
+
from synapse.storage._base import SQLBaseStore
logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
- def get_current_state_deltas(self, prev_stream_id):
+ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
@@ -36,15 +38,27 @@ class StateDeltasStore(SQLBaseStore):
Args:
prev_stream_id (int): point to get changes since (exclusive)
+ max_stream_id (int): the point that we know has been correctly persisted
+ - ie, an upper limit to return changes from.
Returns:
- Deferred[list[dict]]: results
+            Deferred[tuple[int, list[dict]]]: A tuple consisting of:
+ - the stream id which these results go up to
+ - list of current_state_delta_stream rows. If it is empty, we are
+ up to date.
"""
prev_stream_id = int(prev_stream_id)
+
+ # check we're not going backwards
+ assert prev_stream_id <= max_stream_id
+
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
):
- return []
+ # if the CSDs haven't changed between prev_stream_id and now, we
+ # know for certain that they haven't changed between prev_stream_id and
+ # max_stream_id.
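+            # (wrapped in a Deferred so that both code paths return a
+            # Deferred, matching the runInteraction call below)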
+ return defer.succeed((max_stream_id, []))
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
@@ -54,21 +68,29 @@ class StateDeltasStore(SQLBaseStore):
sql = """
SELECT stream_id, count(*)
FROM current_state_delta_stream
- WHERE stream_id > ?
+ WHERE stream_id > ? AND stream_id <= ?
GROUP BY stream_id
ORDER BY stream_id ASC
LIMIT 100
"""
- txn.execute(sql, (prev_stream_id,))
+ txn.execute(sql, (prev_stream_id, max_stream_id))
total = 0
- max_stream_id = prev_stream_id
- for max_stream_id, count in txn:
+
+ for stream_id, count in txn:
total += count
if total > 100:
# We arbitrarily limit to 100 entries to ensure we don't
# select too many.
+ logger.debug(
+ "Clipping current_state_delta_stream rows to stream_id %i",
+ stream_id,
+ )
+ clipped_stream_id = stream_id
break
+ else:
+ # if there's no problem, we may as well go right up to the max_stream_id
+ clipped_stream_id = max_stream_id
# Now actually get the deltas
sql = """
@@ -77,15 +99,15 @@ class StateDeltasStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
- txn.execute(sql, (prev_stream_id, max_stream_id))
- return self.cursor_to_dict(txn)
+ txn.execute(sql, (prev_stream_id, clipped_stream_id))
+ return clipped_stream_id, self.db.cursor_to_dict(txn)
- return self.runInteraction(
+ return self.db.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)
def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
- return self._simple_select_one_onecol_txn(
+ return self.db.simple_select_one_onecol_txn(
txn,
table="current_state_delta_stream",
keyvalues={},
@@ -93,7 +115,7 @@ class StateDeltasStore(SQLBaseStore):
)
def get_max_stream_id_in_current_state_deltas(self):
- return self.runInteraction(
+ return self.db.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
new file mode 100644
index 0000000000..380c1ec7da
--- /dev/null
+++ b/synapse/storage/data_stores/main/stats.py
@@ -0,0 +1,857 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018, 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from itertools import chain
+
+from twisted.internet import defer
+from twisted.internet.defer import DeferredLock
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import cached
+
+logger = logging.getLogger(__name__)
+
+# these fields track absolutes (e.g. total number of rooms on the server)
+# You can think of these as Prometheus Gauges.
+# You can draw these stats on a line graph.
+# Example: number of users in a room
+ABSOLUTE_STATS_FIELDS = {
+ "room": (
+ "current_state_events",
+ "joined_members",
+ "invited_members",
+ "left_members",
+ "banned_members",
+ "local_users_in_room",
+ ),
+ "user": ("joined_rooms",),
+}
+
+# these fields are per-timeslice and so should be reset to 0 upon a new slice
+# You can draw these stats on a histogram.
+# Example: number of events sent locally during a time slice
+PER_SLICE_FIELDS = {
+ "room": ("total_events", "total_event_bytes"),
+ "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"),
+}
+
+TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")}
+
+# these are the tables (& ID columns) which contain our actual subjects
+TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
+
+
+class StatsStore(StateDeltasStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(StatsStore, self).__init__(database, db_conn, hs)
+
+ self.server_name = hs.hostname
+ self.clock = self.hs.get_clock()
+ self.stats_enabled = hs.config.stats_enabled
+ self.stats_bucket_size = hs.config.stats_bucket_size
+
+ self.stats_delta_processing_lock = DeferredLock()
+
+ self.db.updates.register_background_update_handler(
+ "populate_stats_process_rooms", self._populate_stats_process_rooms
+ )
+ self.db.updates.register_background_update_handler(
+ "populate_stats_process_users", self._populate_stats_process_users
+ )
+ # we no longer need to perform clean-up, but we will give ourselves
+ # the potential to reintroduce it in the future – so documentation
+ # will still encourage the use of this no-op handler.
+ self.db.updates.register_noop_background_update("populate_stats_cleanup")
+ self.db.updates.register_noop_background_update("populate_stats_prepare")
+
+ def quantise_stats_time(self, ts):
+ """
+ Quantises a timestamp to be a multiple of the bucket size.
+
+ Args:
+ ts (int): the timestamp to quantise, in milliseconds since the Unix
+ Epoch
+
+ Returns:
+ int: a timestamp which
+ - is divisible by the bucket size;
+ - is no later than `ts`; and
+ - is the largest such timestamp.
+ """
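+        # e.g. with a bucket size of one day (86400000 ms) this rounds a
+        # timestamp down to the most recent UTC midnight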
+ return (ts // self.stats_bucket_size) * self.stats_bucket_size
+
+ @defer.inlineCallbacks
+ def _populate_stats_process_users(self, progress, batch_size):
+ """
+ This is a background update which regenerates statistics for users.
+ """
+ if not self.stats_enabled:
+ yield self.db.updates._end_background_update("populate_stats_process_users")
+ return 1
+
+ last_user_id = progress.get("last_user_id", "")
+
+ def _get_next_batch(txn):
+ sql = """
+ SELECT DISTINCT name FROM users
+ WHERE name > ?
+ ORDER BY name ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_user_id, batch_size))
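+            # each row is a 1-tuple containing the user name, hence the
+            # trailing comma in the unpacking below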
+ return [r for r, in txn]
+
+ users_to_work_on = yield self.db.runInteraction(
+ "_populate_stats_process_users", _get_next_batch
+ )
+
+        # No more users -- complete the background update.
+ if not users_to_work_on:
+ yield self.db.updates._end_background_update("populate_stats_process_users")
+ return 1
+
+ for user_id in users_to_work_on:
+ yield self._calculate_and_set_initial_state_for_user(user_id)
+ progress["last_user_id"] = user_id
+
+ yield self.db.runInteraction(
+ "populate_stats_process_users",
+ self.db.updates._background_update_progress_txn,
+ "populate_stats_process_users",
+ progress,
+ )
+
+ return len(users_to_work_on)
+
+ @defer.inlineCallbacks
+ def _populate_stats_process_rooms(self, progress, batch_size):
+ """
+ This is a background update which regenerates statistics for rooms.
+ """
+ if not self.stats_enabled:
+ yield self.db.updates._end_background_update("populate_stats_process_rooms")
+ return 1
+
+ last_room_id = progress.get("last_room_id", "")
+
+ def _get_next_batch(txn):
+ sql = """
+ SELECT DISTINCT room_id FROM current_state_events
+ WHERE room_id > ?
+ ORDER BY room_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_room_id, batch_size))
+ return [r for r, in txn]
+
+ rooms_to_work_on = yield self.db.runInteraction(
+ "populate_stats_rooms_get_batch", _get_next_batch
+ )
+
+        # No more rooms -- complete the background update.
+ if not rooms_to_work_on:
+ yield self.db.updates._end_background_update("populate_stats_process_rooms")
+ return 1
+
+ for room_id in rooms_to_work_on:
+ yield self._calculate_and_set_initial_state_for_room(room_id)
+ progress["last_room_id"] = room_id
+
+ yield self.db.runInteraction(
+ "_populate_stats_process_rooms",
+ self.db.updates._background_update_progress_txn,
+ "populate_stats_process_rooms",
+ progress,
+ )
+
+ return len(rooms_to_work_on)
+
+ def get_stats_positions(self):
+ """
+ Returns the stats processor positions.
+ """
+ return self.db.simple_select_one_onecol(
+ table="stats_incremental_position",
+ keyvalues={},
+ retcol="stream_id",
+ desc="stats_incremental_position",
+ )
+
+ def update_room_state(self, room_id, fields):
+ """
+ Args:
+ room_id (str)
+ fields (dict[str:Any])
+ """
+
+ # For whatever reason some of the fields may contain null bytes, which
+ # postgres isn't a fan of, so we replace those fields with null.
+ for col in (
+ "join_rules",
+ "history_visibility",
+ "encryption",
+ "name",
+ "topic",
+ "avatar",
+ "canonical_alias",
+ ):
+ field = fields.get(col)
+ if field and "\0" in field:
+ fields[col] = None
+
+ return self.db.simple_upsert(
+ table="room_stats_state",
+ keyvalues={"room_id": room_id},
+ values=fields,
+ desc="update_room_state",
+ )
+
+ def get_statistics_for_subject(self, stats_type, stats_id, start, size=100):
+ """
+ Get statistics for a given subject.
+
+ Args:
+ stats_type (str): The type of subject
+ stats_id (str): The ID of the subject (e.g. room_id or user_id)
+ start (int): Pagination start. Number of entries, not timestamp.
+ size (int): How many entries to return.
+
+ Returns:
+ Deferred[list[dict]], where the dict has the keys of
+ ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
+ """
+ return self.db.runInteraction(
+ "get_statistics_for_subject",
+ self._get_statistics_for_subject_txn,
+ stats_type,
+ stats_id,
+ start,
+ size,
+ )
+
+ def _get_statistics_for_subject_txn(
+ self, txn, stats_type, stats_id, start, size=100
+ ):
+ """
+ Transaction-bound version of L{get_statistics_for_subject}.
+ """
+
+ table, id_col = TYPE_TO_TABLE[stats_type]
+ selected_columns = list(
+ ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
+ )
+
+ slice_list = self.db.simple_select_list_paginate_txn(
+ txn,
+ table + "_historical",
+ "end_ts",
+ start,
+ size,
+ retcols=selected_columns + ["bucket_size", "end_ts"],
+ keyvalues={id_col: stats_id},
+ order_direction="DESC",
+ )
+
+ return slice_list
+
+ @cached()
+ def get_earliest_token_for_stats(self, stats_type, id):
+ """
+ Fetch the "earliest token". This is used by the room stats delta
+ processor to ignore deltas that have been processed between the
+ start of the background task and any particular room's stats
+ being calculated.
+
+ Returns:
+ Deferred[int]
+ """
+ table, id_col = TYPE_TO_TABLE[stats_type]
+
+ return self.db.simple_select_one_onecol(
+ "%s_current" % (table,),
+ keyvalues={id_col: id},
+ retcol="completed_delta_stream_id",
+ allow_none=True,
+ )
+
+ def bulk_update_stats_delta(self, ts, updates, stream_id):
+ """Bulk update stats tables for a given stream_id and updates the stats
+ incremental position.
+
+ Args:
+ ts (int): Current timestamp in ms
+            updates (dict[str, dict[str, dict[str, Counter]]]): The updates to
+ commit as a mapping stats_type -> stats_id -> field -> delta.
+ stream_id (int): Current position.
+
+ Returns:
+ Deferred
+ """
+
+ def _bulk_update_stats_delta_txn(txn):
+ for stats_type, stats_updates in updates.items():
+ for stats_id, fields in stats_updates.items():
+ logger.debug(
+ "Updating %s stats for %s: %s", stats_type, stats_id, fields
+ )
+ self._update_stats_delta_txn(
+ txn,
+ ts=ts,
+ stats_type=stats_type,
+ stats_id=stats_id,
+ fields=fields,
+ complete_with_stream_id=stream_id,
+ )
+
+ self.db.simple_update_one_txn(
+ txn,
+ table="stats_incremental_position",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ )
+
+ return self.db.runInteraction(
+ "bulk_update_stats_delta", _bulk_update_stats_delta_txn
+ )
+
+ def update_stats_delta(
+ self,
+ ts,
+ stats_type,
+ stats_id,
+ fields,
+ complete_with_stream_id,
+ absolute_field_overrides=None,
+ ):
+ """
+ Updates the statistics for a subject, with a delta (difference/relative
+ change).
+
+ Args:
+ ts (int): timestamp of the change
+ stats_type (str): "room" or "user" – the kind of subject
+ stats_id (str): the subject's ID (room ID or user ID)
+ fields (dict[str, int]): Deltas of stats values.
+ complete_with_stream_id (int, optional):
+ If supplied, converts an incomplete row into a complete row,
+ with the supplied stream_id marked as the stream_id where the
+ row was completed.
+ absolute_field_overrides (dict[str, int]): Current stats values
+ (i.e. not deltas) of absolute fields.
+ Does not work with per-slice fields.
+ """
+
+ return self.db.runInteraction(
+ "update_stats_delta",
+ self._update_stats_delta_txn,
+ ts,
+ stats_type,
+ stats_id,
+ fields,
+ complete_with_stream_id=complete_with_stream_id,
+ absolute_field_overrides=absolute_field_overrides,
+ )
+
+ def _update_stats_delta_txn(
+ self,
+ txn,
+ ts,
+ stats_type,
+ stats_id,
+ fields,
+ complete_with_stream_id,
+ absolute_field_overrides=None,
+ ):
+ if absolute_field_overrides is None:
+ absolute_field_overrides = {}
+
+ table, id_col = TYPE_TO_TABLE[stats_type]
+
+ quantised_ts = self.quantise_stats_time(int(ts))
+ end_ts = quantised_ts + self.stats_bucket_size
+
+        # Let's be paranoid and check that all the given field names are known
+ abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type]
+ slice_field_names = PER_SLICE_FIELDS[stats_type]
+ for field in chain(fields.keys(), absolute_field_overrides.keys()):
+ if field not in abs_field_names and field not in slice_field_names:
+ # guard against potential SQL injection dodginess
+ raise ValueError(
+ "%s is not a recognised field"
+ " for stats type %s" % (field, stats_type)
+ )
+
+ # Per slice fields do not get added to the _current table
+
+ # This calculates the deltas (`field = field + ?` values)
+ # for absolute fields,
+ # * defaulting to 0 if not specified
+ # (required for the INSERT part of upserting to work)
+ # * omitting overrides specified in `absolute_field_overrides`
+ deltas_of_absolute_fields = {
+ key: fields.get(key, 0)
+ for key in abs_field_names
+ if key not in absolute_field_overrides
+ }
+
+ # Keep the delta stream ID field up to date
+ absolute_field_overrides = absolute_field_overrides.copy()
+ absolute_field_overrides["completed_delta_stream_id"] = complete_with_stream_id
+
+ # first upsert the `_current` table
+ self._upsert_with_additive_relatives_txn(
+ txn=txn,
+ table=table + "_current",
+ keyvalues={id_col: stats_id},
+ absolutes=absolute_field_overrides,
+ additive_relatives=deltas_of_absolute_fields,
+ )
+
+ per_slice_additive_relatives = {
+ key: fields.get(key, 0) for key in slice_field_names
+ }
+ self._upsert_copy_from_table_with_additive_relatives_txn(
+ txn=txn,
+ into_table=table + "_historical",
+ keyvalues={id_col: stats_id},
+ extra_dst_insvalues={"bucket_size": self.stats_bucket_size},
+ extra_dst_keyvalues={"end_ts": end_ts},
+ additive_relatives=per_slice_additive_relatives,
+ src_table=table + "_current",
+ copy_columns=abs_field_names,
+ )
+
+ def _upsert_with_additive_relatives_txn(
+ self, txn, table, keyvalues, absolutes, additive_relatives
+ ):
+ """Used to update values in the stats tables.
+
+ This is basically a slightly convoluted upsert that *adds* to any
+ existing rows.
+
+ Args:
+ txn
+ table (str): Table name
+ keyvalues (dict[str, any]): Row-identifying key values
+ absolutes (dict[str, any]): Absolute (set) fields
+ additive_relatives (dict[str, int]): Fields that will be added onto
+                the existing row, if one is present.
+ """
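+        # e.g. with absolutes={"completed_delta_stream_id": 7} and
+        # additive_relatives={"joined_members": 1}, an existing row has
+        # joined_members incremented by 1 and completed_delta_stream_id set to
+        # 7; if no row matches `keyvalues`, a new row is inserted with those
+        # values.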
+ if self.database_engine.can_native_upsert:
+ absolute_updates = [
+ "%(field)s = EXCLUDED.%(field)s" % {"field": field}
+ for field in absolutes.keys()
+ ]
+
+ relative_updates = [
+ "%(field)s = EXCLUDED.%(field)s + %(table)s.%(field)s"
+ % {"table": table, "field": field}
+ for field in additive_relatives.keys()
+ ]
+
+ insert_cols = []
+ qargs = []
+
+ for (key, val) in chain(
+ keyvalues.items(), absolutes.items(), additive_relatives.items()
+ ):
+ insert_cols.append(key)
+ qargs.append(val)
+
+ sql = """
+ INSERT INTO %(table)s (%(insert_cols_cs)s)
+ VALUES (%(insert_vals_qs)s)
+ ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s
+ """ % {
+ "table": table,
+ "insert_cols_cs": ", ".join(insert_cols),
+ "insert_vals_qs": ", ".join(
+ ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives))
+ ),
+ "key_columns": ", ".join(keyvalues),
+ "updates": ", ".join(chain(absolute_updates, relative_updates)),
+ }
+
+ txn.execute(sql, qargs)
+ else:
+ self.database_engine.lock_table(txn, table)
+ retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
+ current_row = self.db.simple_select_one_txn(
+ txn, table, keyvalues, retcols, allow_none=True
+ )
+ if current_row is None:
+ merged_dict = {**keyvalues, **absolutes, **additive_relatives}
+ self.db.simple_insert_txn(txn, table, merged_dict)
+ else:
+ for (key, val) in additive_relatives.items():
+ current_row[key] += val
+ current_row.update(absolutes)
+ self.db.simple_update_one_txn(txn, table, keyvalues, current_row)
+
+ def _upsert_copy_from_table_with_additive_relatives_txn(
+ self,
+ txn,
+ into_table,
+ keyvalues,
+ extra_dst_keyvalues,
+ extra_dst_insvalues,
+ additive_relatives,
+ src_table,
+ copy_columns,
+ ):
+ """Updates the historic stats table with latest updates.
+
+ This involves copying "absolute" fields from the `_current` table, and
+ adding relative fields to any existing values.
+
+ Args:
+ txn: Transaction
+ into_table (str): The destination table to UPSERT the row into
+ keyvalues (dict[str, any]): Row-identifying key values
+ extra_dst_keyvalues (dict[str, any]): Additional keyvalues
+ for `into_table`.
+ extra_dst_insvalues (dict[str, any]): Additional values to insert
+ on new row creation for `into_table`.
+ additive_relatives (dict[str, any]): Fields that will be added onto
+                the existing row, if one is present. (Must be disjoint from copy_columns.)
+ src_table (str): The source table to copy from
+ copy_columns (iterable[str]): The list of columns to copy
+ """
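+        # In the stats tables this copies the absolute columns (e.g.
+        # joined_members) from the `_current` table into the matching
+        # `_historical` bucket, while adding the per-slice counters (e.g.
+        # total_events) onto whatever the bucket already holds.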
+ if self.database_engine.can_native_upsert:
+ ins_columns = chain(
+ keyvalues,
+ copy_columns,
+ additive_relatives,
+ extra_dst_keyvalues,
+ extra_dst_insvalues,
+ )
+ sel_exprs = chain(
+ keyvalues,
+ copy_columns,
+ (
+ "?"
+ for _ in chain(
+ additive_relatives, extra_dst_keyvalues, extra_dst_insvalues
+ )
+ ),
+ )
+ keyvalues_where = ("%s = ?" % f for f in keyvalues)
+
+ sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns)
+ sets_ar = (
+ "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f)
+ for f in additive_relatives
+ )
+
+ sql = """
+ INSERT INTO %(into_table)s (%(ins_columns)s)
+ SELECT %(sel_exprs)s
+ FROM %(src_table)s
+ WHERE %(keyvalues_where)s
+ ON CONFLICT (%(keyvalues)s)
+ DO UPDATE SET %(sets)s
+ """ % {
+ "into_table": into_table,
+ "ins_columns": ", ".join(ins_columns),
+ "sel_exprs": ", ".join(sel_exprs),
+ "keyvalues_where": " AND ".join(keyvalues_where),
+ "src_table": src_table,
+ "keyvalues": ", ".join(
+ chain(keyvalues.keys(), extra_dst_keyvalues.keys())
+ ),
+ "sets": ", ".join(chain(sets_cc, sets_ar)),
+ }
+
+ qargs = list(
+ chain(
+ additive_relatives.values(),
+ extra_dst_keyvalues.values(),
+ extra_dst_insvalues.values(),
+ keyvalues.values(),
+ )
+ )
+ txn.execute(sql, qargs)
+ else:
+ self.database_engine.lock_table(txn, into_table)
+ src_row = self.db.simple_select_one_txn(
+ txn, src_table, keyvalues, copy_columns
+ )
+ all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
+ dest_current_row = self.db.simple_select_one_txn(
+ txn,
+ into_table,
+ keyvalues=all_dest_keyvalues,
+ retcols=list(chain(additive_relatives.keys(), copy_columns)),
+ allow_none=True,
+ )
+
+ if dest_current_row is None:
+ merged_dict = {
+ **keyvalues,
+ **extra_dst_keyvalues,
+ **extra_dst_insvalues,
+ **src_row,
+ **additive_relatives,
+ }
+ self.db.simple_insert_txn(txn, into_table, merged_dict)
+ else:
+ for (key, val) in additive_relatives.items():
+ src_row[key] = dest_current_row[key] + val
+ self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
+
+ def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
+ """Fetches the counts of events in the given range of stream IDs.
+
+ Args:
+ min_pos (int)
+ max_pos (int)
+
+ Returns:
+            Deferred[tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]]:
+                The room and user deltas, each as a mapping of room/user ID to
+                field changes.
+ """
+
+ return self.db.runInteraction(
+ "stats_incremental_total_events_and_bytes",
+ self.get_changes_room_total_events_and_bytes_txn,
+ min_pos,
+ max_pos,
+ )
+
+ def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+ """Gets the total_events and total_event_bytes counts for rooms and
+ senders, in a range of stream_orderings (including backfilled events).
+
+ Args:
+ txn
+ low_pos (int): Low stream ordering
+ high_pos (int): High stream ordering
+
+ Returns:
+ tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
+ room and user deltas for total_events/total_event_bytes in the
+ format of `stats_id` -> fields
+ """
+
+ if low_pos >= high_pos:
+ # nothing to do here.
+ return {}, {}
+
+ if isinstance(self.database_engine, PostgresEngine):
+ new_bytes_expression = "OCTET_LENGTH(json)"
+ else:
+ new_bytes_expression = "LENGTH(CAST(json AS BLOB))"
+
+ sql = """
+ SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes
+ FROM events INNER JOIN event_json USING (event_id)
+ WHERE (? < stream_ordering AND stream_ordering <= ?)
+ OR (? <= stream_ordering AND stream_ordering <= ?)
+ GROUP BY events.room_id
+ """ % (
+ new_bytes_expression,
+ )
+
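+        # backfilled events are stored with negative stream orderings, hence
+        # the second (negated) range in the query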
+ txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos))
+
+ room_deltas = {
+ room_id: {"total_events": new_events, "total_event_bytes": new_bytes}
+ for room_id, new_events, new_bytes in txn
+ }
+
+ sql = """
+ SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes
+ FROM events INNER JOIN event_json USING (event_id)
+ WHERE (? < stream_ordering AND stream_ordering <= ?)
+ OR (? <= stream_ordering AND stream_ordering <= ?)
+ GROUP BY events.sender
+ """ % (
+ new_bytes_expression,
+ )
+
+ txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos))
+
+ user_deltas = {
+ user_id: {"total_events": new_events, "total_event_bytes": new_bytes}
+ for user_id, new_events, new_bytes in txn
+ if self.hs.is_mine_id(user_id)
+ }
+
+ return room_deltas, user_deltas
+
+ @defer.inlineCallbacks
+ def _calculate_and_set_initial_state_for_room(self, room_id):
+ """Calculate and insert an entry into room_stats_current.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[tuple[dict, dict, int]]: A tuple of room state, membership
+ counts and stream position.
+ """
+
+ def _fetch_current_state_stats(txn):
+ pos = self.get_room_max_stream_ordering()
+
+ rows = self.db.simple_select_many_txn(
+ txn,
+ table="current_state_events",
+ column="type",
+ iterable=[
+ EventTypes.Create,
+ EventTypes.JoinRules,
+ EventTypes.RoomHistoryVisibility,
+ EventTypes.RoomEncryption,
+ EventTypes.Name,
+ EventTypes.Topic,
+ EventTypes.RoomAvatar,
+ EventTypes.CanonicalAlias,
+ ],
+ keyvalues={"room_id": room_id, "state_key": ""},
+ retcols=["event_id"],
+ )
+
+ event_ids = [row["event_id"] for row in rows]
+
+ txn.execute(
+ """
+ SELECT membership, count(*) FROM current_state_events
+ WHERE room_id = ? AND type = 'm.room.member'
+ GROUP BY membership
+ """,
+ (room_id,),
+ )
+ membership_counts = {membership: cnt for membership, cnt in txn}
+
+ txn.execute(
+ """
+ SELECT COALESCE(count(*), 0) FROM current_state_events
+ WHERE room_id = ?
+ """,
+ (room_id,),
+ )
+
+ (current_state_events_count,) = txn.fetchone()
+
+ users_in_room = self.get_users_in_room_txn(txn, room_id)
+
+ return (
+ event_ids,
+ membership_counts,
+ current_state_events_count,
+ users_in_room,
+ pos,
+ )
+
+ (
+ event_ids,
+ membership_counts,
+ current_state_events_count,
+ users_in_room,
+ pos,
+ ) = yield self.db.runInteraction(
+ "get_initial_state_for_room", _fetch_current_state_stats
+ )
+
+ state_event_map = yield self.get_events(event_ids, get_prev_content=False)
+
+ room_state = {
+ "join_rules": None,
+ "history_visibility": None,
+ "encryption": None,
+ "name": None,
+ "topic": None,
+ "avatar": None,
+ "canonical_alias": None,
+ "is_federatable": True,
+ }
+
+ for event in state_event_map.values():
+ if event.type == EventTypes.JoinRules:
+ room_state["join_rules"] = event.content.get("join_rule")
+ elif event.type == EventTypes.RoomHistoryVisibility:
+ room_state["history_visibility"] = event.content.get(
+ "history_visibility"
+ )
+ elif event.type == EventTypes.RoomEncryption:
+ room_state["encryption"] = event.content.get("algorithm")
+ elif event.type == EventTypes.Name:
+ room_state["name"] = event.content.get("name")
+ elif event.type == EventTypes.Topic:
+ room_state["topic"] = event.content.get("topic")
+ elif event.type == EventTypes.RoomAvatar:
+ room_state["avatar"] = event.content.get("url")
+ elif event.type == EventTypes.CanonicalAlias:
+ room_state["canonical_alias"] = event.content.get("alias")
+ elif event.type == EventTypes.Create:
+ room_state["is_federatable"] = (
+ event.content.get("m.federate", True) is True
+ )
+
+ yield self.update_room_state(room_id, room_state)
+
+ local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
+
+ yield self.update_stats_delta(
+ ts=self.clock.time_msec(),
+ stats_type="room",
+ stats_id=room_id,
+ fields={},
+ complete_with_stream_id=pos,
+ absolute_field_overrides={
+ "current_state_events": current_state_events_count,
+ "joined_members": membership_counts.get(Membership.JOIN, 0),
+ "invited_members": membership_counts.get(Membership.INVITE, 0),
+ "left_members": membership_counts.get(Membership.LEAVE, 0),
+ "banned_members": membership_counts.get(Membership.BAN, 0),
+ "local_users_in_room": len(local_users_in_room),
+ },
+ )
+
+ @defer.inlineCallbacks
+ def _calculate_and_set_initial_state_for_user(self, user_id):
+ def _calculate_and_set_initial_state_for_user_txn(txn):
+ pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
+
+ txn.execute(
+ """
+ SELECT COUNT(distinct room_id) FROM current_state_events
+ WHERE type = 'm.room.member' AND state_key = ?
+ AND membership = 'join'
+ """,
+ (user_id,),
+ )
+ (count,) = txn.fetchone()
+ return count, pos
+
+ joined_rooms, pos = yield self.db.runInteraction(
+ "calculate_and_set_initial_state_for_user",
+ _calculate_and_set_initial_state_for_user_txn,
+ )
+
+ yield self.update_stats_delta(
+ ts=self.clock.time_msec(),
+ stats_type="user",
+ stats_id=user_id,
+ fields={},
+ complete_with_stream_id=pos,
+ absolute_field_overrides={"joined_rooms": joined_rooms},
+ )
diff --git a/synapse/storage/stream.py b/synapse/storage/data_stores/main/stream.py
index 6f7f65d96b..ada5cce6c2 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -41,12 +44,13 @@ from six.moves import range
from twisted.internet import defer
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
-from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -65,7 +69,7 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
- direction, column_names, from_token, to_token, engine,
+ direction, column_names, from_token, to_token, engine
):
"""Creates an SQL expression to bound the columns by the pagination
tokens.
@@ -153,7 +157,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
str
"""
- assert(bound in (">", "<", ">=", "<="))
+ assert bound in (">", "<", ">=", "<=")
name1, name2 = column_names
val1, val2 = values
@@ -169,11 +173,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the latter form when running against postgres.
- return "((%d,%d) %s (%s,%s))" % (
- val1, val2,
- bound,
- name1, name2,
- )
+ return "((%d,%d) %s (%s,%s))" % (val1, val2, bound, name1, name2)
# We want to generate queries of e.g. the form:
#
@@ -233,6 +233,14 @@ def filter_to_clause(event_filter):
clauses.append("contains_url = ?")
args.append(event_filter.contains_url)
+ # We're only applying the "labels" filter on the database query, because applying the
+ # "not_labels" filter via a SQL query is non-trivial. Instead, we let
+ # event_filter.check_fields apply it, which is not as efficient but makes the
+ # implementation simpler.
+ if event_filter.labels:
+ clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
+ args.extend(event_filter.labels)
+
return " AND ".join(clauses), args
@@ -244,11 +252,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta
- def __init__(self, db_conn, hs):
- super(StreamWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(StreamWorkerStore, self).__init__(database, db_conn, hs)
events_max = self.get_room_max_stream_ordering()
- event_cache_prefill, min_event_val = self._get_cache_dict(
+ event_cache_prefill, min_event_val = self.db.get_cache_dict(
db_conn,
"events",
entity_column="room_id",
@@ -276,7 +284,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(
- self, room_ids, from_key, to_key, limit=0, order='DESC'
+ self, room_ids, from_key, to_key, limit=0, order="DESC"
):
"""Get new room events in stream ordering since `from_key`.
@@ -304,7 +312,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
if not room_ids:
- defer.returnValue({})
+ return {}
results = {}
room_ids = list(room_ids)
@@ -327,7 +335,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
results.update(dict(zip(rm_ids, res)))
- defer.returnValue(results)
+ return results
def get_rooms_that_changed(self, room_ids, from_key):
"""Given a list of rooms and a token, return rooms where there may have
@@ -338,15 +346,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_key (str): The room_key portion of a StreamToken
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
- return set(
+ return {
room_id
for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
- )
+ }
@defer.inlineCallbacks
def get_room_events_stream_for_room(
- self, room_id, from_key, to_key, limit=0, order='DESC'
+ self, room_id, from_key, to_key, limit=0, order="DESC"
):
"""Get new room events in stream ordering since `from_key`.
@@ -368,7 +376,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
the chunk of events returned.
"""
if from_key == to_key:
- defer.returnValue(([], from_key))
+ return [], from_key
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -378,7 +386,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
if not has_changed:
- defer.returnValue(([], from_key))
+ return [], from_key
def f(txn):
sql = (
@@ -393,10 +401,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
- rows = yield self.runInteraction("get_room_events_stream_for_room", f)
+ rows = yield self.db.runInteraction("get_room_events_stream_for_room", f)
- ret = yield self.get_events_as_list([
- r.event_id for r in rows], get_prev_content=True,
+ ret = yield self.get_events_as_list(
+ [r.event_id for r in rows], get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=from_id is None)
@@ -411,7 +419,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# get.
key = from_key
- defer.returnValue((ret, key))
+ return ret, key
@defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
@@ -419,14 +427,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
- defer.returnValue([])
+ return []
if from_id:
has_changed = self._membership_stream_cache.has_entity_changed(
user_id, int(from_id)
)
if not has_changed:
- defer.returnValue([])
+ return []
def f(txn):
sql = (
@@ -443,15 +451,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
- rows = yield self.runInteraction("get_membership_changes_for_user", f)
+ rows = yield self.db.runInteraction("get_membership_changes_for_user", f)
ret = yield self.get_events_as_list(
- [r.event_id for r in rows], get_prev_content=True,
+ [r.event_id for r in rows], get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=False)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token):
@@ -481,7 +489,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._set_before_and_after(events, rows)
- defer.returnValue((events, token))
+ return (events, token)
@defer.inlineCallbacks
def get_recent_event_ids_for_room(self, room_id, limit, end_token):
@@ -500,11 +508,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
# Allow a zero limit here, and no-op.
if limit == 0:
- defer.returnValue(([], end_token))
+ return [], end_token
end_token = RoomStreamToken.parse(end_token)
- rows, token = yield self.runInteraction(
+ rows, token = yield self.db.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@@ -515,10 +523,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# We want to return the results in ascending order.
rows.reverse()
- defer.returnValue((rows, token))
+ return rows, token
- def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
- """Gets details of the first event in a room at or after a stream ordering
+ def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
+ """Gets details of the first event in a room at or before a stream ordering
Args:
room_id (str):
@@ -533,15 +541,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
- " WHERE room_id = ? AND stream_ordering >= ?"
+ " WHERE room_id = ? AND stream_ordering <= ?"
" AND NOT outlier"
- " ORDER BY stream_ordering"
+ " ORDER BY stream_ordering DESC"
" LIMIT 1"
)
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
- return self.runInteraction("get_room_event_after_stream_ordering", _f)
+ return self.db.runInteraction("get_room_event_before_stream_ordering", _f)
@defer.inlineCallbacks
def get_room_events_max_id(self, room_id=None):
@@ -553,12 +561,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
token = yield self.get_room_max_stream_ordering()
if room_id is None:
- defer.returnValue("s%d" % (token,))
+ return "s%d" % (token,)
else:
- topo = yield self.runInteraction(
+ topo = yield self.db.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
- defer.returnValue("t%d-%d" % (topo, token))
+ return "t%d-%d" % (topo, token)
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
@@ -569,7 +577,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "s%d" stream token.
"""
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
).addCallback(lambda row: "s%d" % (row,))
@@ -582,7 +590,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "t%d-%d" topological token.
"""
- return self._simple_select_one(
+ return self.db.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
@@ -606,13 +614,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
- return self._execute(
+ return self.db.execute(
"get_max_topological_token", None, sql, room_id, stream_key
).addCallback(lambda r: r[0][0] if r else 0)
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
- "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
+ "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
)
@@ -660,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- results = yield self.runInteraction(
+ results = yield self.db.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@@ -671,21 +679,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
events_before = yield self.get_events_as_list(
- [e for e in results["before"]["event_ids"]], get_prev_content=True
+ list(results["before"]["event_ids"]), get_prev_content=True
)
events_after = yield self.get_events_as_list(
- [e for e in results["after"]["event_ids"]], get_prev_content=True
+ list(results["after"]["event_ids"]), get_prev_content=True
)
- defer.returnValue(
- {
- "events_before": events_before,
- "events_after": events_after,
- "start": results["before"]["token"],
- "end": results["after"]["token"],
- }
- )
+ return {
+ "events_before": events_before,
+ "events_after": events_after,
+ "start": results["before"]["token"],
+ "end": results["after"]["token"],
+ }
def _get_events_around_txn(
self, txn, room_id, event_id, before_limit, after_limit, event_filter
@@ -704,7 +710,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- results = self._simple_select_one_txn(
+ results = self.db.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
@@ -725,7 +731,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
before_token,
- direction='b',
+ direction="b",
limit=before_limit,
event_filter=event_filter,
)
@@ -735,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
after_token,
- direction='f',
+ direction="f",
limit=after_limit,
event_filter=event_filter,
)
@@ -783,16 +789,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.runInteraction(
+ upper_bound, event_ids = yield self.db.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
events = yield self.get_events_as_list(event_ids)
- defer.returnValue((upper_bound, events))
+ return upper_bound, events
def get_federation_out_pos(self, typ):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="federation_stream_position",
retcol="stream_id",
keyvalues={"type": typ},
@@ -800,7 +806,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
def update_federation_out_pos(self, typ, stream_id):
- return self._simple_update_one(
+ return self.db.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ},
updatevalues={"stream_id": stream_id},
@@ -816,7 +822,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id,
from_token,
to_token=None,
- direction='b',
+ direction="b",
limit=-1,
event_filter=None,
):
@@ -837,7 +843,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
as a list of _EventDictReturn and a token that points to the end
- of the result set.
+ of the result set. If no events are returned then the end of the
+ stream has been reached (i.e. there are no events between
+ `from_token` and `to_token`), or `limit` is zero.
"""
assert int(limit) >= 0
@@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
args = [False, room_id]
- if direction == 'b':
+ if direction == "b":
order = "DESC"
else:
order = "ASC"
@@ -867,13 +875,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args.append(int(limit))
- sql = (
- "SELECT event_id, topological_ordering, stream_ordering"
- " FROM events"
- " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
- " ORDER BY topological_ordering %(order)s,"
- " stream_ordering %(order)s LIMIT ?"
- ) % {"bounds": bounds, "order": order}
+ select_keywords = "SELECT"
+ join_clause = ""
+ if event_filter and event_filter.labels:
+ # If we're not filtering on a label, then joining on event_labels will
+ # return as many rows for a single event as the number of labels it has. To
+ # avoid this, only join if we're filtering on at least one label.
+ join_clause = """
+ LEFT JOIN event_labels
+ USING (event_id, room_id, topological_ordering)
+ """
+ if len(event_filter.labels) > 1:
+ # Using DISTINCT in this SELECT query is quite expensive, because it
+ # requires the engine to sort on the entire (not limited) result set,
+ # i.e. the entire events table. We only need to use it when we're
+ # filtering on more than one label, because that's the only scenario
+ # in which we can possibly get the same event ID multiple times in
+ # the results.
+ select_keywords += " DISTINCT"
+
+ sql = """
+ %(select_keywords)s event_id, topological_ordering, stream_ordering
+ FROM events
+ %(join_clause)s
+ WHERE outlier = ? AND room_id = ? AND %(bounds)s
+ ORDER BY topological_ordering %(order)s,
+ stream_ordering %(order)s LIMIT ?
+ """ % {
+ "select_keywords": select_keywords,
+ "join_clause": join_clause,
+ "bounds": bounds,
+ "order": order,
+ }
txn.execute(sql, args)
@@ -882,7 +915,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if rows:
topo = rows[-1].topological_ordering
toke = rows[-1].stream_ordering
- if direction == 'b':
+ if direction == "b":
# Tokens are positions between events.
# This token points *after* the last event in the chunk.
# We need it to point to the event before it in the chunk
@@ -898,7 +931,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def paginate_room_events(
- self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None
+ self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
):
"""Returns list of events before or after a given token.
@@ -909,22 +942,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
only those before
direction(char): Either 'b' or 'f' to indicate whether we are
paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return. Zero or less
- means no limit.
+ limit (int): The maximum number of events to return.
event_filter (Filter|None): If provided filters the events to
those that match the filter.
Returns:
- tuple[list[dict], str]: Returns the results as a list of dicts and
- a token that points to the end of the result set. The dicts have
- the keys "event_id", "topological_ordering" and "stream_orderign".
+ tuple[list[FrozenEvent], str]: Returns the results as a list of
+ events and a token that points to the end of the result set. If no
+ events are returned then the end of the stream has been reached
+ (i.e. there are no events between `from_key` and `to_key`).
"""
from_key = RoomStreamToken.parse(from_key)
if to_key:
to_key = RoomStreamToken.parse(to_key)
- rows, token = yield self.runInteraction(
+ rows, token = yield self.db.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
@@ -941,7 +974,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._set_before_and_after(events, rows)
- defer.returnValue((events, token))
+ return (events, token)
class StreamStore(StreamWorkerStore):
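
The pagination changes above assemble the SELECT dynamically: the event_labels join is added only when the filter actually contains labels, and DISTINCT only when more than one label is requested (an event carrying several of the requested labels would otherwise appear once per matching label). A minimal Python sketch of that assembly, not Synapse code, with `labels` standing in for `event_filter.labels` and a simplified bound standing in for the real %(bounds)s clause:

    def build_pagination_sql(labels, order="DESC"):
        # Sketch only: mirrors the SELECT assembly in _paginate_room_events_txn.
        select_keywords = "SELECT"
        join_clause = ""
        if labels:
            # Joining event_labels yields one row per label on an event, so
            # only join when we are actually filtering on labels.
            join_clause = """
                LEFT JOIN event_labels
                USING (event_id, room_id, topological_ordering)
            """
            if len(labels) > 1:
                # Two or more requested labels can match the same event, so
                # deduplicate the rows.
                select_keywords += " DISTINCT"
        return """
            %(select_keywords)s event_id, topological_ordering, stream_ordering
            FROM events
            %(join_clause)s
            WHERE outlier = ? AND room_id = ? AND topological_ordering <= ?
            ORDER BY topological_ordering %(order)s, stream_ordering %(order)s
            LIMIT ?
        """ % {
            "select_keywords": select_keywords,
            "join_clause": join_clause,
            "order": order,
        }

    print(build_pagination_sql(["#fun"]))           # join, no DISTINCT
    print(build_pagination_sql(["#fun", "#work"]))  # join and DISTINCT
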
diff --git a/synapse/storage/tags.py b/synapse/storage/data_stores/main/tags.py
index e88f8ea35f..2aa1bafd48 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -22,7 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag strings to tag content.
"""
- deferred = self._simple_select_list(
+ deferred = self.db.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
@@ -66,7 +66,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
room_id string, tag string and content string.
"""
if last_id == current_id:
- defer.returnValue([])
+ return []
def get_all_updated_tags_txn(txn):
sql = (
@@ -78,14 +78,12 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- tag_ids = yield self.runInteraction(
+ tag_ids = yield self.db.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
def get_tag_content(txn, tag_ids):
- sql = (
- "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?"
- )
+ sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id))
@@ -100,14 +98,14 @@ class TagsWorkerStore(AccountDataWorkerStore):
batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
- tags = yield self.runInteraction(
+ tags = yield self.db.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
)
results.extend(tags)
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
@@ -135,9 +133,11 @@ class TagsWorkerStore(AccountDataWorkerStore):
user_id, int(stream_id)
)
if not changed:
- defer.returnValue({})
+ return {}
- room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
+ room_ids = yield self.db.runInteraction(
+ "get_updated_tags", get_updated_tags_txn
+ )
results = {}
if room_ids:
@@ -145,7 +145,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {})
- defer.returnValue(results)
+ return results
def get_tags_for_room(self, user_id, room_id):
"""Get all the tags for the given room
@@ -155,7 +155,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns:
A deferred list of string tags.
"""
- return self._simple_select_list(
+ return self.db.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
@@ -180,7 +180,7 @@ class TagsStore(TagsWorkerStore):
content_json = json.dumps(content)
def add_tag_txn(txn, next_id):
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
@@ -189,12 +189,12 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
- yield self.runInteraction("add_tag", add_tag_txn, next_id)
+ yield self.db.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def remove_tag_from_room(self, user_id, room_id, tag):
@@ -212,12 +212,12 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
- yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
+ yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
def _update_revision_txn(self, txn, user_id, room_id, next_id):
"""Update the latest revision of the tags for the given user and room.
diff --git a/synapse/storage/transactions.py b/synapse/storage/data_stores/main/transactions.py
index b1188f6bcb..5b07c2fbc0 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -23,10 +23,10 @@ from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
-from ._base import SQLBaseStore, db_to_json
-
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
@@ -53,8 +53,8 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
- def __init__(self, db_conn, hs):
- super(TransactionStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(TransactionStore, self).__init__(database, db_conn, hs)
self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
@@ -78,7 +78,7 @@ class TransactionStore(SQLBaseStore):
this transaction or a 2-tuple of (int, dict)
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_received_txn_response",
self._get_received_txn_response,
transaction_id,
@@ -86,7 +86,7 @@ class TransactionStore(SQLBaseStore):
)
def _get_received_txn_response(self, txn, transaction_id, origin):
- result = self._simple_select_one_txn(
+ result = self.db.simple_select_one_txn(
txn,
table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin},
@@ -120,7 +120,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
- return self._simple_insert(
+ return self.db.simple_insert(
table="received_transactions",
values={
"transaction_id": transaction_id,
@@ -133,34 +133,6 @@ class TransactionStore(SQLBaseStore):
desc="set_received_txn_response",
)
- def prep_send_transaction(self, transaction_id, destination, origin_server_ts):
- """Persists an outgoing transaction and calculates the values for the
- previous transaction id list.
-
- This should be called before sending the transaction so that it has the
- correct value for the `prev_ids` key.
-
- Args:
- transaction_id (str)
- destination (str)
- origin_server_ts (int)
-
- Returns:
- list: A list of previous transaction ids.
- """
- return defer.succeed([])
-
- def delivered_txn(self, transaction_id, destination, code, response_dict):
- """Persists the response for an outgoing transaction.
-
- Args:
- transaction_id (str)
- destination (str)
- code (int)
- response_json (str)
- """
- pass
-
@defer.inlineCallbacks
def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination.
@@ -175,9 +147,9 @@ class TransactionStore(SQLBaseStore):
result = self._destination_retry_cache.get(destination, SENTINEL)
if result is not SENTINEL:
- defer.returnValue(result)
+ return result
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings,
destination,
@@ -186,14 +158,14 @@ class TransactionStore(SQLBaseStore):
# We don't hugely care about race conditions between getting and
# invalidating the cache, since we time out fairly quickly anyway.
self._destination_retry_cache[destination] = result
- defer.returnValue(result)
+ return result
def _get_destination_retry_timings(self, txn, destination):
- result = self._simple_select_one_txn(
+ result = self.db.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
- retcols=("destination", "retry_last_ts", "retry_interval"),
+ retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
allow_none=True,
)
@@ -202,82 +174,91 @@ class TransactionStore(SQLBaseStore):
else:
return None
- def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval):
+ def set_destination_retry_timings(
+ self, destination, failure_ts, retry_last_ts, retry_interval
+ ):
"""Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occurring.
Args:
destination (str)
+ failure_ts (int|None) - when the server started failing (ms since epoch)
retry_last_ts (int) - time of last retry attempt in unix epoch ms
retry_interval (int) - how long until next retry in ms
"""
self._destination_retry_cache.pop(destination, None)
- return self.runInteraction(
+ return self.db.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
+ failure_ts,
retry_last_ts,
retry_interval,
)
def _set_destination_retry_timings(
- self, txn, destination, retry_last_ts, retry_interval
+ self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
+
+ if self.database_engine.can_native_upsert:
+ # Upsert retry time interval if retry_interval is zero (i.e. we're
+ # resetting it) or greater than the existing retry interval.
+
+ sql = """
+ INSERT INTO destinations (
+ destination, failure_ts, retry_last_ts, retry_interval
+ )
+ VALUES (?, ?, ?, ?)
+ ON CONFLICT (destination) DO UPDATE SET
+ failure_ts = EXCLUDED.failure_ts,
+ retry_last_ts = EXCLUDED.retry_last_ts,
+ retry_interval = EXCLUDED.retry_interval
+ WHERE
+ EXCLUDED.retry_interval = 0
+ OR destinations.retry_interval < EXCLUDED.retry_interval
+ """
+
+ txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
+
+ return
+
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us
# due to a worker setting the timings.
- prev_row = self._simple_select_one_txn(
+ prev_row = self.db.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
- retcols=("retry_last_ts", "retry_interval"),
+ retcols=("failure_ts", "retry_last_ts", "retry_interval"),
allow_none=True,
)
if not prev_row:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="destinations",
values={
"destination": destination,
+ "failure_ts": failure_ts,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
},
)
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
"destinations",
keyvalues={"destination": destination},
updatevalues={
+ "failure_ts": failure_ts,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
},
)
- def get_destinations_needing_retry(self):
- """Get all destinations which are due a retry for sending a transaction.
-
- Returns:
- list: A list of dicts
- """
-
- return self.runInteraction(
- "get_destinations_needing_retry", self._get_destinations_needing_retry
- )
-
- def _get_destinations_needing_retry(self, txn):
- query = (
- "SELECT * FROM destinations"
- " WHERE retry_last_ts > 0 and retry_next_ts < ?"
- )
-
- txn.execute(query, (self._clock.time_msec(),))
- return self.cursor_to_dict(txn)
-
def _start_cleanup_transactions(self):
return run_as_background_process(
"cleanup_transactions", self._cleanup_transactions
@@ -290,4 +271,6 @@ class TransactionStore(SQLBaseStore):
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
- return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
+ return self.db.runInteraction(
+ "_cleanup_transactions", _cleanup_transactions_txn
+ )
diff --git a/synapse/storage/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 83466e25d9..6b8130bf0f 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -19,10 +19,10 @@ import re
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.data_stores.main.state import StateFilter
+from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.storage.state import StateFilter
-from synapse.storage.state_deltas import StateDeltasStore
from synapse.types import get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
@@ -32,30 +32,30 @@ logger = logging.getLogger(__name__)
TEMP_TABLE = "_temp_populate_user_directory"
-class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
+class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, db_conn, hs):
- super(UserDirectoryStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.server_name = hs.hostname
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"populate_user_directory_createtables",
self._populate_user_directory_createtables,
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"populate_user_directory_process_rooms",
self._populate_user_directory_process_rooms,
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"populate_user_directory_process_users",
self._populate_user_directory_process_users,
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)
@@ -85,7 +85,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"""
txn.execute(sql)
rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
- self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+ self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
del rooms
# If search all users is on, get all the users we want to add.
@@ -100,23 +100,25 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
txn.execute("SELECT name FROM users")
users = [{"user_id": x[0]} for x in txn.fetchall()]
- self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
new_pos = yield self.get_max_stream_id_in_current_state_deltas()
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory_temp_build", _make_staging_area
)
- yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+ yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
- yield self._end_background_update("populate_user_directory_createtables")
- defer.returnValue(1)
+ yield self.db.updates._end_background_update(
+ "populate_user_directory_createtables"
+ )
+ return 1
@defer.inlineCallbacks
def _populate_user_directory_cleanup(self, progress, batch_size):
"""
Update the user directory stream position, then clean up the old tables.
"""
- position = yield self._simple_select_one_onecol(
+ position = yield self.db.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
)
yield self.update_user_directory_stream_pos(position)
@@ -126,12 +128,12 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory_cleanup", _delete_staging_area
)
- yield self._end_background_update("populate_user_directory_cleanup")
- defer.returnValue(1)
+ yield self.db.updates._end_background_update("populate_user_directory_cleanup")
+ return 1
@defer.inlineCallbacks
def _populate_user_directory_process_rooms(self, progress, batch_size):
@@ -170,16 +172,18 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
return rooms_to_work_on
- rooms_to_work_on = yield self.runInteraction(
+ rooms_to_work_on = yield self.db.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
- yield self._end_background_update("populate_user_directory_process_rooms")
- defer.returnValue(1)
+ yield self.db.updates._end_background_update(
+ "populate_user_directory_process_rooms"
+ )
+ return 1
- logger.info(
+ logger.debug(
"Processing the next %d rooms of %d remaining"
% (len(rooms_to_work_on), progress["remaining"])
)
@@ -243,12 +247,12 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
to_insert.clear()
# We've finished a room. Delete it from the table.
- yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
+ yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory",
- self._background_update_progress_txn,
+ self.db.updates._background_update_progress_txn,
"populate_user_directory_process_rooms",
progress,
)
@@ -257,9 +261,9 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
if processed_event_count > batch_size:
# Don't process any more rooms, we've hit our batch size.
- defer.returnValue(processed_event_count)
+ return processed_event_count
- defer.returnValue(processed_event_count)
+ return processed_event_count
@defer.inlineCallbacks
def _populate_user_directory_process_users(self, progress, batch_size):
@@ -267,8 +271,10 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
If search_all_users is enabled, add all of the users to the user directory.
"""
if not self.hs.config.user_directory_search_all_users:
- yield self._end_background_update("populate_user_directory_process_users")
- defer.returnValue(1)
+ yield self.db.updates._end_background_update(
+ "populate_user_directory_process_users"
+ )
+ return 1
def _get_next_batch(txn):
sql = "SELECT user_id FROM %s LIMIT %s" % (
@@ -291,16 +297,18 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
return users_to_work_on
- users_to_work_on = yield self.runInteraction(
+ users_to_work_on = yield self.db.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more users -- complete the transaction.
if not users_to_work_on:
- yield self._end_background_update("populate_user_directory_process_users")
- defer.returnValue(1)
+ yield self.db.updates._end_background_update(
+ "populate_user_directory_process_users"
+ )
+ return 1
- logger.info(
+ logger.debug(
"Processing the next %d users of %d remaining"
% (len(users_to_work_on), progress["remaining"])
)
@@ -312,17 +320,17 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
)
# We've finished processing a user. Delete it from the table.
- yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
+ yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory",
- self._background_update_progress_txn,
+ self.db.updates._background_update_progress_txn,
"populate_user_directory_process_users",
progress,
)
- defer.returnValue(len(users_to_work_on))
+ return len(users_to_work_on)
@defer.inlineCallbacks
def is_room_world_readable_or_publicly_joinable(self, room_id):
@@ -344,16 +352,16 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
- defer.returnValue(True)
+ return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
if hist_vis_ev.content.get("history_visibility") == "world_readable":
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
"""
@@ -361,7 +369,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"""
def _update_profile_in_user_dir_txn(txn):
- new_entry = self._simple_upsert_txn(
+ new_entry = self.db.simple_upsert_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
@@ -435,7 +443,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
)
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name) if display_name else user_id
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="user_directory_search",
keyvalues={"user_id": user_id},
@@ -448,59 +456,10 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.runInteraction(
+ return self.db.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
- def remove_from_user_dir(self, user_id):
- def _remove_from_user_dir_txn(txn):
- self._simple_delete_txn(
- txn, table="user_directory", keyvalues={"user_id": user_id}
- )
- self._simple_delete_txn(
- txn, table="user_directory_search", keyvalues={"user_id": user_id}
- )
- self._simple_delete_txn(
- txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
- )
- self._simple_delete_txn(
- txn,
- table="users_who_share_private_rooms",
- keyvalues={"user_id": user_id},
- )
- self._simple_delete_txn(
- txn,
- table="users_who_share_private_rooms",
- keyvalues={"other_user_id": user_id},
- )
- txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
-
- return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
-
- @defer.inlineCallbacks
- def get_users_in_dir_due_to_room(self, room_id):
- """Get all user_ids that are in the room directory because they're
- in the given room_id
- """
- user_ids_share_pub = yield self._simple_select_onecol(
- table="users_in_public_rooms",
- keyvalues={"room_id": room_id},
- retcol="user_id",
- desc="get_users_in_dir_due_to_room",
- )
-
- user_ids_share_priv = yield self._simple_select_onecol(
- table="users_who_share_private_rooms",
- keyvalues={"room_id": room_id},
- retcol="other_user_id",
- desc="get_users_in_dir_due_to_room",
- )
-
- user_ids = set(user_ids_share_pub)
- user_ids.update(user_ids_share_priv)
-
- defer.returnValue(user_ids)
-
def add_users_who_share_private_room(self, room_id, user_id_tuples):
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
@@ -511,7 +470,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"""
def _add_users_who_share_room_txn(txn):
- self._simple_upsert_many_txn(
+ self.db.simple_upsert_many_txn(
txn,
table="users_who_share_private_rooms",
key_names=["user_id", "other_user_id", "room_id"],
@@ -523,7 +482,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
value_values=None,
)
- return self.runInteraction(
+ return self.db.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
@@ -538,7 +497,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
def _add_users_in_public_rooms_txn(txn):
- self._simple_upsert_many_txn(
+ self.db.simple_upsert_many_txn(
txn,
table="users_in_public_rooms",
key_names=["user_id", "room_id"],
@@ -547,10 +506,102 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
value_values=None,
)
- return self.runInteraction(
+ return self.db.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
+ def delete_all_from_user_dir(self):
+ """Delete the entire user directory
+ """
+
+ def _delete_all_from_user_dir_txn(txn):
+ txn.execute("DELETE FROM user_directory")
+ txn.execute("DELETE FROM user_directory_search")
+ txn.execute("DELETE FROM users_in_public_rooms")
+ txn.execute("DELETE FROM users_who_share_private_rooms")
+ txn.call_after(self.get_user_in_directory.invalidate_all)
+
+ return self.db.runInteraction(
+ "delete_all_from_user_dir", _delete_all_from_user_dir_txn
+ )
+
+ @cached()
+ def get_user_in_directory(self, user_id):
+ return self.db.simple_select_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ retcols=("display_name", "avatar_url"),
+ allow_none=True,
+ desc="get_user_in_directory",
+ )
+
+ def update_user_directory_stream_pos(self, stream_id):
+ return self.db.simple_update_one(
+ table="user_directory_stream_pos",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ desc="update_user_directory_stream_pos",
+ )
+
+
+class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
+
+ # How many records do we calculate before sending it to
+ # add_users_who_share_private_rooms?
+ SHARE_PRIVATE_WORKING_SET = 500
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(UserDirectoryStore, self).__init__(database, db_conn, hs)
+
+ def remove_from_user_dir(self, user_id):
+ def _remove_from_user_dir_txn(txn):
+ self.db.simple_delete_txn(
+ txn, table="user_directory", keyvalues={"user_id": user_id}
+ )
+ self.db.simple_delete_txn(
+ txn, table="user_directory_search", keyvalues={"user_id": user_id}
+ )
+ self.db.simple_delete_txn(
+ txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
+ )
+ self.db.simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"user_id": user_id},
+ )
+ self.db.simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"other_user_id": user_id},
+ )
+ txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+
+ return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
+
+ @defer.inlineCallbacks
+ def get_users_in_dir_due_to_room(self, room_id):
+ """Get all user_ids that are in the room directory because they're
+ in the given room_id
+ """
+ user_ids_share_pub = yield self.db.simple_select_onecol(
+ table="users_in_public_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids_share_priv = yield self.db.simple_select_onecol(
+ table="users_who_share_private_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="other_user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids = set(user_ids_share_pub)
+ user_ids.update(user_ids_share_priv)
+
+ return user_ids
+
def remove_user_who_share_room(self, user_id, room_id):
"""
Deletes entries in the users_who_share_*_rooms table. The first
@@ -562,23 +613,23 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"""
def _remove_user_who_share_room_txn(txn):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id, "room_id": room_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="users_in_public_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
@@ -593,14 +644,14 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
Returns:
list: user_id
"""
- rows = yield self._simple_select_onecol(
+ rows = yield self.db.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
desc="get_rooms_user_is_in",
)
- pub_rows = yield self._simple_select_onecol(
+ pub_rows = yield self.db.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
@@ -609,7 +660,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
users = set(pub_rows)
users.update(rows)
- defer.returnValue(list(users))
+ return list(users)
@defer.inlineCallbacks
def get_rooms_in_common_for_users(self, user_id, other_user_id):
@@ -618,66 +669,33 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
sql = """
SELECT room_id FROM (
SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
+ INNER JOIN room_memberships AS m USING (event_id)
WHERE type = 'm.room.member'
- AND membership = 'join'
+ AND m.membership = 'join'
AND state_key = ?
) AS f1 INNER JOIN (
SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
+ INNER JOIN room_memberships AS m USING (event_id)
WHERE type = 'm.room.member'
- AND membership = 'join'
+ AND m.membership = 'join'
AND state_key = ?
) f2 USING (room_id)
"""
- rows = yield self._execute(
+ rows = yield self.db.execute(
"get_rooms_in_common_for_users", None, sql, user_id, other_user_id
)
- defer.returnValue([room_id for room_id, in rows])
-
- def delete_all_from_user_dir(self):
- """Delete the entire user directory
- """
-
- def _delete_all_from_user_dir_txn(txn):
- txn.execute("DELETE FROM user_directory")
- txn.execute("DELETE FROM user_directory_search")
- txn.execute("DELETE FROM users_in_public_rooms")
- txn.execute("DELETE FROM users_who_share_private_rooms")
- txn.call_after(self.get_user_in_directory.invalidate_all)
-
- return self.runInteraction(
- "delete_all_from_user_dir", _delete_all_from_user_dir_txn
- )
-
- @cached()
- def get_user_in_directory(self, user_id):
- return self._simple_select_one(
- table="user_directory",
- keyvalues={"user_id": user_id},
- retcols=("display_name", "avatar_url"),
- allow_none=True,
- desc="get_user_in_directory",
- )
+ return [room_id for room_id, in rows]
def get_user_directory_stream_pos(self):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
desc="get_user_directory_stream_pos",
)
- def update_user_directory_stream_pos(self, stream_id):
- return self._simple_update_one(
- table="user_directory_stream_pos",
- keyvalues={},
- updatevalues={"stream_id": stream_id},
- desc="update_user_directory_stream_pos",
- )
-
@defer.inlineCallbacks
def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory
@@ -776,13 +794,13 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- results = yield self._execute(
- "search_user_dir", self.cursor_to_dict, sql, *args
+ results = yield self.db.execute(
+ "search_user_dir", self.db.cursor_to_dict, sql, *args
)
limited = len(results) > limit
- defer.returnValue({"limited": limited, "results": results})
+ return {"limited": limited, "results": results}
def _parse_query_sqlite(search_term):
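
The populate_user_directory_* background updates above all follow one staged pattern: read a batch from a temporary staging table, process it, delete the processed rows, record progress, and finish once the staging table is empty. A stripped-down sketch of that loop (table and names here are illustrative, not the real schema):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE staging_rooms (room_id TEXT)")
    conn.executemany(
        "INSERT INTO staging_rooms VALUES (?)",
        [("!a:example.com",), ("!b:example.com",), ("!c:example.com",)],
    )

    def run_batch(batch_size):
        rows = conn.execute(
            "SELECT room_id FROM staging_rooms LIMIT ?", (batch_size,)
        ).fetchall()
        if not rows:
            return None  # staging table drained: the background update is done
        for (room_id,) in rows:
            # the real update would rebuild directory entries for room_id here
            conn.execute("DELETE FROM staging_rooms WHERE room_id = ?", (room_id,))
        return len(rows)  # progress made in this batch

    while True:
        processed = run_batch(2)
        if processed is None:
            break
        print("processed %d rooms" % processed)
    # processed 2 rooms
    # processed 1 rooms
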
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index 1815fdc0dd..ec6b8a4ffd 100644
--- a/synapse/storage/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import operator
-from twisted.internet import defer
+import operator
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
@@ -32,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if the user has requested erasure
"""
- return self._simple_select_onecol(
+ return self.db.simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
@@ -57,17 +56,17 @@ class UserErasureWorkerStore(SQLBaseStore):
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
- def _get_erased_users(txn):
- txn.execute(
- "SELECT user_id FROM erased_users WHERE user_id IN (%s)"
- % (",".join("?" * len(user_ids))),
- user_ids,
- )
- return set(r[0] for r in txn)
+ rows = yield self.db.simple_select_many_batch(
+ table="erased_users",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="are_users_erased",
+ )
+ erased_users = {row["user_id"] for row in rows}
- erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
- res = dict((u, u in erased_users) for u in user_ids)
- defer.returnValue(res)
+ res = {u: u in erased_users for u in user_ids}
+ return res
class UserErasureStore(UserErasureWorkerStore):
@@ -89,4 +88,4 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.runInteraction("mark_user_erased", f)
+ return self.db.runInteraction("mark_user_erased", f)
diff --git a/synapse/storage/data_stores/state/__init__.py b/synapse/storage/data_stores/state/__init__.py
new file mode 100644
index 0000000000..86e09f6229
--- /dev/null
+++ b/synapse/storage/data_stores/state/__init__.py
@@ -0,0 +1,16 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py
new file mode 100644
index 0000000000..e8edaf9f7b
--- /dev/null
+++ b/synapse/storage/data_stores/state/bg_updates.py
@@ -0,0 +1,374 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import iteritems
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.state import StateFilter
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class StateGroupBackgroundUpdateStore(SQLBaseStore):
+ """Defines functions related to state groups needed to run the state background
+ updates.
+ """
+
+ def _count_state_group_hops_txn(self, txn, state_group):
+ """Given a state group, count how many hops there are in the tree.
+
+ This is used to ensure the delta chains don't get too long.
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT count(*) FROM state;
+ """
+
+ txn.execute(sql, (state_group,))
+ row = txn.fetchone()
+ if row and row[0]:
+ return row[0]
+ else:
+ return 0
+ else:
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ next_group = state_group
+ count = 0
+
+ while next_group:
+ next_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+ if next_group:
+ count += 1
+
+ return count
+
+ def _get_state_groups_from_groups_txn(
+ self, txn, groups, state_filter=StateFilter.all()
+ ):
+ results = {group: {} for group in groups}
+
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ # Unless the filter clause is empty, we're going to append it after an
+ # existing where clause
+ if where_clause:
+ where_clause = " AND (%s)" % (where_clause,)
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # Temporarily disable sequential scans in this transaction. This is
+ # a temporary hack until we can add the right indices in
+ txn.execute("SET LOCAL enable_seqscan=off")
+
+ # The below query walks the state_group tree so that the "state"
+ # table includes all state_groups in the tree. It then joins
+ # against `state_groups_state` to fetch the latest state.
+ # It assumes that previous state groups are always numerically
+ # lesser.
+ # The PARTITION is used to get the event_id in the greatest state
+ # group for the given type, state_key.
+ # This may return multiple rows per (type, state_key), but last_value
+ # should be the same.
+ sql = """
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT DISTINCT type, state_key, last_value(event_id) OVER (
+ PARTITION BY type, state_key ORDER BY state_group ASC
+ ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ ) AS event_id FROM state_groups_state
+ WHERE state_group IN (
+ SELECT state_group FROM state
+ )
+ """
+
+ for group in groups:
+ args = [group]
+ args.extend(where_args)
+
+ txn.execute(sql + where_clause, args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (typ, state_key)
+ results[group][key] = event_id
+ else:
+ max_entries_returned = state_filter.max_entries_returned()
+
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ for group in groups:
+ next_group = group
+
+ while next_group:
+ # We did this before by getting the list of group ids, and
+ # then passing that list to sqlite to get latest event for
+ # each (type, state_key). However, that was terribly slow
+ # without the right indices (which we can't add until
+ # after we finish deduping state, which requires this func)
+ args = [next_group]
+ args.extend(where_args)
+
+ txn.execute(
+ "SELECT type, state_key, event_id FROM state_groups_state"
+ " WHERE state_group = ? " + where_clause,
+ args,
+ )
+ results[group].update(
+ ((typ, state_key), event_id)
+ for typ, state_key, event_id in txn
+ if (typ, state_key) not in results[group]
+ )
+
+ # If the number of entries in the (type,state_key)->event_id dict
+ # matches the number of (type,state_keys) types we were searching
+ # for, then we must have found them all, so no need to go walk
+ # further down the tree... UNLESS our types filter contained
+ # wildcards (i.e. Nones) in which case we have to do an exhaustive
+ # search
+ if (
+ max_entries_returned is not None
+ and len(results[group]) == max_entries_returned
+ ):
+ break
+
+ next_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ return results
+
+
+class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
+
+ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+ STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+ STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ self.db.updates.register_background_update_handler(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+ self._background_deduplicate_state,
+ )
+ self.db.updates.register_background_update_handler(
+ self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
+ )
+ self.db.updates.register_background_index_update(
+ self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME,
+ index_name="state_groups_room_id_idx",
+ table="state_groups",
+ columns=["room_id"],
+ )
+
+ @defer.inlineCallbacks
+ def _background_deduplicate_state(self, progress, batch_size):
+ """This background update will slowly deduplicate state by re-encoding
+ it as deltas.
+ """
+ last_state_group = progress.get("last_state_group", 0)
+ rows_inserted = progress.get("rows_inserted", 0)
+ max_group = progress.get("max_group", None)
+
+ BATCH_SIZE_SCALE_FACTOR = 100
+
+ batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
+
+ if max_group is None:
+ rows = yield self.db.execute(
+ "_background_deduplicate_state",
+ None,
+ "SELECT coalesce(max(id), 0) FROM state_groups",
+ )
+ max_group = rows[0][0]
+
+ def reindex_txn(txn):
+ new_last_state_group = last_state_group
+ for count in range(batch_size):
+ txn.execute(
+ "SELECT id, room_id FROM state_groups"
+ " WHERE ? < id AND id <= ?"
+ " ORDER BY id ASC"
+ " LIMIT 1",
+ (new_last_state_group, max_group),
+ )
+ row = txn.fetchone()
+ if row:
+ state_group, room_id = row
+
+ if not row or not state_group:
+ return True, count
+
+ txn.execute(
+ "SELECT state_group FROM state_group_edges"
+ " WHERE state_group = ?",
+ (state_group,),
+ )
+
+ # If we reach a point where we've already started inserting
+ # edges we should stop.
+ if txn.fetchall():
+ return True, count
+
+ txn.execute(
+ "SELECT coalesce(max(id), 0) FROM state_groups"
+ " WHERE id < ? AND room_id = ?",
+ (state_group, room_id),
+ )
+ (prev_group,) = txn.fetchone()
+ new_last_state_group = state_group
+
+ if prev_group:
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+ if potential_hops >= MAX_STATE_DELTA_HOPS:
+ # We want to ensure chains are at most this long,
+ # otherwise read performance degrades.
+ continue
+
+ prev_state = self._get_state_groups_from_groups_txn(
+ txn, [prev_group]
+ )
+ prev_state = prev_state[prev_group]
+
+ curr_state = self._get_state_groups_from_groups_txn(
+ txn, [state_group]
+ )
+ curr_state = curr_state[state_group]
+
+ if not set(prev_state.keys()) - set(curr_state.keys()):
+ # We can only do a delta if the current has a strict super set
+ # of keys
+
+ delta_state = {
+ key: value
+ for key, value in iteritems(curr_state)
+ if prev_state.get(key, None) != value
+ }
+
+ self.db.simple_delete_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": state_group},
+ )
+
+ self.db.simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={
+ "state_group": state_group,
+ "prev_state_group": prev_group,
+ },
+ )
+
+ self.db.simple_delete_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(delta_state)
+ ],
+ )
+
+ progress = {
+ "last_state_group": state_group,
+ "rows_inserted": rows_inserted + batch_size,
+ "max_group": max_group,
+ }
+
+ self.db.updates._background_update_progress_txn(
+ txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
+ )
+
+ return False, batch_size
+
+ finished, result = yield self.db.runInteraction(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
+ )
+
+ if finished:
+ yield self.db.updates._end_background_update(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
+ )
+
+ return result * BATCH_SIZE_SCALE_FACTOR
+
+ @defer.inlineCallbacks
+ def _background_index_state(self, progress, batch_size):
+ def reindex_txn(conn):
+ conn.rollback()
+ if isinstance(self.database_engine, PostgresEngine):
+ # postgres insists on autocommit for the index
+ conn.set_session(autocommit=True)
+ try:
+ txn = conn.cursor()
+ txn.execute(
+ "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
+ " ON state_groups_state(state_group, type, state_key)"
+ )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+ finally:
+ conn.set_session(autocommit=False)
+ else:
+ txn = conn.cursor()
+ txn.execute(
+ "CREATE INDEX state_groups_state_type_idx"
+ " ON state_groups_state(state_group, type, state_key)"
+ )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+
+ yield self.db.runWithConnection(reindex_txn)
+
+ yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+
+ return 1
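
On SQLite, where the WITH RECURSIVE query above is avoided, state for a group is resolved by walking prev_state_group edges towards the root and letting the nearest group win each (type, state_key). A small in-memory sketch of that walk; `edges` and `group_state` are hypothetical stand-ins for the state_group_edges and state_groups_state tables:

    # Not Synapse code: dict-based stand-ins for state_group_edges and
    # state_groups_state.
    edges = {3: 2, 2: 1}  # state_group -> prev_state_group
    group_state = {
        1: {("m.room.create", ""): "$create"},
        2: {("m.room.member", "@a:hs"): "$join_a"},
        3: {("m.room.member", "@a:hs"): "$leave_a"},
    }

    def resolve_state(state_group):
        # Walk from `state_group` to the root; deltas nearer the start win.
        result = {}
        hops = 0
        group = state_group
        while group is not None:
            for key, event_id in group_state.get(group, {}).items():
                result.setdefault(key, event_id)
            group = edges.get(group)
            hops += 1
        return result, hops - 1  # hops - 1 == number of edges walked

    print(resolve_state(3))
    # ({('m.room.member', '@a:hs'): '$leave_a', ('m.room.create', ''): '$create'}, 2)
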
diff --git a/synapse/storage/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
index ae09fa0065..ae09fa0065 100644
--- a/synapse/storage/schema/delta/23/drop_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
index e85699e82e..e85699e82e 100644
--- a/synapse/storage/schema/delta/30/state_stream.sql
+++ b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
new file mode 100644
index 0000000000..1450313bfa
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
@@ -0,0 +1,19 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- The following indices are redundant, other indices are equivalent or
+-- supersets
+DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
diff --git a/synapse/storage/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
index 0fce26345b..33980d02f0 100644
--- a/synapse/storage/schema/delta/35/add_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
@@ -13,8 +13,5 @@
* limitations under the License.
*/
-
-ALTER TABLE background_updates ADD COLUMN depends_on TEXT;
-
INSERT into background_updates (update_name, progress_json, depends_on)
VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication');
diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql
index 0f1fa68a89..0f1fa68a89 100644
--- a/synapse/storage/schema/delta/35/state.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state.sql
diff --git a/synapse/storage/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
index 97e5067ef4..97e5067ef4 100644
--- a/synapse/storage/schema/delta/35/state_dedupe.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
index f6766501d2..9fd1ccf6f7 100644
--- a/synapse/storage/schema/delta/47/state_group_seq.py
+++ b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
@@ -27,10 +27,7 @@ def run_create(cur, database_engine, *args, **kwargs):
else:
start_val = row[0] + 1
- cur.execute(
- "CREATE SEQUENCE state_group_id_seq START WITH %s",
- (start_val, ),
- )
+ cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,))
def run_upgrade(*args, **kwargs):
diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
new file mode 100644
index 0000000000..7916ef18b2
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('state_groups_room_id_idx', '{}');
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
new file mode 100644
index 0000000000..35f97d6b3d
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
@@ -0,0 +1,37 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE state_groups (
+ id BIGINT PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_groups_state (
+ state_group BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_group_edges (
+ state_group BIGINT NOT NULL,
+ prev_state_group BIGINT NOT NULL
+);
+
+CREATE INDEX state_group_edges_idx ON state_group_edges (state_group);
+CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group);
+CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key);
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
new file mode 100644
index 0000000000..fcd926c9fb
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE state_group_id_seq
+ START WITH 1
+ INCREMENT BY 1
+ NO MINVALUE
+ NO MAXVALUE
+ CACHE 1;
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
new file mode 100644
index 0000000000..57a5267663
--- /dev/null
+++ b/synapse/storage/data_stores/state/store.py
@@ -0,0 +1,644 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import namedtuple
+from typing import Dict, Iterable, List, Set, Tuple
+
+from six import iteritems
+from six.moves import range
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
+from synapse.storage.database import Database
+from synapse.storage.state import StateFilter
+from synapse.types import StateMap
+from synapse.util.caches import get_cache_factor_for
+from synapse.util.caches.descriptors import cached
+from synapse.util.caches.dictionary_cache import DictionaryCache
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class _GetStateGroupDelta(
+ namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
+ """Return type of get_state_group_delta that implements __len__, which lets
+    us use the iterable flag when caching
+ """
+
+ __slots__ = []
+
+ def __len__(self):
+ return len(self.delta_ids) if self.delta_ids else 0
+
+
+class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
+ """A data store for fetching/storing state groups.
+ """
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateGroupDataStore, self).__init__(database, db_conn, hs)
+
+ # Originally the state store used a single DictionaryCache to cache the
+ # event IDs for the state types in a given state group to avoid hammering
+ # on the state_group* tables.
+ #
+ # The point of using a DictionaryCache is that it can cache a subset
+ # of the state events for a given state group (i.e. a subset of the keys for a
+ # given dict which is an entry in the cache for a given state group ID).
+ #
+ # However, this poses problems when performing complicated queries
+ # on the store - for instance: "give me all the state for this group, but
+ # limit members to this subset of users", as DictionaryCache's API isn't
+ # rich enough to say "please cache any of these fields, apart from this subset".
+ # This is problematic when lazy loading members, which requires this behaviour,
+ # as without it the cache has no choice but to speculatively load all
+ # state events for the group, which negates the efficiency being sought.
+ #
+ # Rather than overcomplicating DictionaryCache's API, we instead split the
+ # state_group_cache into two halves - one for tracking non-member events,
+ # and the other for tracking member_events. This means that lazy loading
+ # queries can be made in a cache-friendly manner by querying both caches
+ # separately and then merging the result. So for the example above, you
+ # would query the members cache for a specific subset of state keys
+ # (which DictionaryCache will handle efficiently and fine) and the non-members
+ # cache for all state (which DictionaryCache will similarly handle fine)
+ # and then just merge the results together.
+ #
+ # We size the non-members cache to be smaller than the members cache as the
+ # vast majority of state in Matrix (today) is member events.
+
+ self._state_group_cache = DictionaryCache(
+ "*stateGroupCache*",
+ # TODO: this hasn't been tuned yet
+ 50000 * get_cache_factor_for("stateGroupCache"),
+ )
+ self._state_group_members_cache = DictionaryCache(
+ "*stateGroupMembersCache*",
+ 500000 * get_cache_factor_for("stateGroupMembersCache"),
+ )
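+
+        # As an illustration (not executed here; `group`, `member_filter` and
+        # `non_member_filter` are just local names): a lazy-loading lookup
+        # queries the two caches separately and merges the results, roughly:
+        #
+        #   members, _ = self._get_state_for_group_using_cache(
+        #       self._state_group_members_cache, group, member_filter
+        #   )
+        #   others, _ = self._get_state_for_group_using_cache(
+        #       self._state_group_cache, group, non_member_filter
+        #   )
+        #   state = dict(others)
+        #   state.update(members)
+        #
+        # `_get_state_for_groups` below does this for many groups at once,
+        # falling back to the database for anything the caches are missing.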
+
+ @cached(max_entries=10000, iterable=True)
+ def get_state_group_delta(self, state_group):
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Returns:
+ (prev_group, delta_ids), where both may be None.
+ """
+
+ def _get_state_group_delta_txn(txn):
+ prev_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": state_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ if not prev_group:
+ return _GetStateGroupDelta(None, None)
+
+ delta_ids = self.db.simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ retcols=("type", "state_key", "event_id"),
+ )
+
+ return _GetStateGroupDelta(
+ prev_group,
+ {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+ )
+
+ return self.db.runInteraction(
+ "get_state_group_delta", _get_state_group_delta_txn
+ )
+
+ @defer.inlineCallbacks
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ):
+ """Returns the state groups for a given set of groups from the
+ database, filtering on types of state events.
+
+ Args:
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+ """
+ results = {}
+
+ chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
+ for chunk in chunks:
+ res = yield self.db.runInteraction(
+ "_get_state_groups_from_groups",
+ self._get_state_groups_from_groups_txn,
+ chunk,
+ state_filter,
+ )
+ results.update(res)
+
+ return results
+
+ def _get_state_for_group_using_cache(self, cache, group, state_filter):
+ """Checks if group is in cache. See `_get_state_for_groups`
+
+ Args:
+ cache(DictionaryCache): the state group cache to use
+ group(int): The state group to lookup
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns 2-tuple (`state_dict`, `got_all`).
+ `got_all` is a bool indicating if we successfully retrieved all
+        requested state from the cache; if False we need to query the DB for the
+ missing state.
+ """
+ is_all, known_absent, state_dict_ids = cache.get(group)
+
+ if is_all or state_filter.is_full():
+ # Either we have everything or want everything, either way
+ # `is_all` tells us whether we've gotten everything.
+ return state_filter.filter_state(state_dict_ids), is_all
+
+ # tracks whether any of our requested types are missing from the cache
+ missing_types = False
+
+ if state_filter.has_wildcards():
+ # We don't know if we fetched all the state keys for the types in
+ # the filter that are wildcards, so we have to assume that we may
+ # have missed some.
+ missing_types = True
+ else:
+ # There aren't any wild cards, so `concrete_types()` returns the
+ # complete list of event types we're wanting.
+ for key in state_filter.concrete_types():
+ if key not in state_dict_ids and key not in known_absent:
+ missing_types = True
+ break
+
+ return state_filter.filter_state(state_dict_ids), not missing_types
+
+ @defer.inlineCallbacks
+ def _get_state_for_groups(
+ self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ ):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups: list of state groups for which we want
+ to get the state.
+ state_filter: The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+ """
+
+ member_filter, non_member_filter = state_filter.get_member_split()
+
+ # Now we look them up in the member and non-member caches
+ (
+ non_member_state,
+ incomplete_groups_nm,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache, state_filter=non_member_filter
+ )
+
+ (
+ member_state,
+ incomplete_groups_m,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_members_cache, state_filter=member_filter
+ )
+
+ state = dict(non_member_state)
+ for group in groups:
+ state[group].update(member_state[group])
+
+ # Now fetch any missing groups from the database
+
+ incomplete_groups = incomplete_groups_m | incomplete_groups_nm
+
+ if not incomplete_groups:
+ return state
+
+ cache_sequence_nm = self._state_group_cache.sequence
+ cache_sequence_m = self._state_group_members_cache.sequence
+
+ # Help the cache hit ratio by expanding the filter a bit
+ db_state_filter = state_filter.return_expanded()
+
+ group_to_state_dict = yield self._get_state_groups_from_groups(
+ list(incomplete_groups), state_filter=db_state_filter
+ )
+
+        # Now let's update the caches
+ self._insert_into_cache(
+ group_to_state_dict,
+ db_state_filter,
+ cache_seq_num_members=cache_sequence_m,
+ cache_seq_num_non_members=cache_sequence_nm,
+ )
+
+ # And finally update the result dict, by filtering out any extra
+ # stuff we pulled out of the database.
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ # We just replace any existing entries, as we will have loaded
+ # everything we need from the database anyway.
+ state[group] = state_filter.filter_state(group_state_dict)
+
+ return state
+
+ def _get_state_for_groups_using_cache(
+ self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
+ ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key, querying from a specific cache.
+
+ Args:
+ groups: list of state groups for which we want to get the state.
+ cache: the cache of group ids to state dicts which
+ we will pass through - either the normal state cache or the
+ specific members state cache.
+ state_filter: The state filter used to fetch state from the
+ database.
+
+ Returns:
+ Tuple of dict of state_group_id to state map of entries in the
+ cache, and the state group ids either missing from the cache or
+ incomplete.
+ """
+ results = {}
+ incomplete_groups = set()
+ for group in set(groups):
+ state_dict_ids, got_all = self._get_state_for_group_using_cache(
+ cache, group, state_filter
+ )
+ results[group] = state_dict_ids
+
+ if not got_all:
+ incomplete_groups.add(group)
+
+ return results, incomplete_groups
+
+ def _insert_into_cache(
+ self,
+ group_to_state_dict,
+ state_filter,
+ cache_seq_num_members,
+ cache_seq_num_non_members,
+ ):
+ """Inserts results from querying the database into the relevant cache.
+
+ Args:
+ group_to_state_dict (dict): The new entries pulled from database.
+ Map from state group to state dict
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ cache_seq_num_members (int): Sequence number of member cache since
+ last lookup in cache
+            cache_seq_num_non_members (int): Sequence number of non-member
+                cache since last lookup in cache
+ """
+
+ # We need to work out which types we've fetched from the DB for the
+ # member vs non-member caches. This should be as accurate as possible,
+ # but can be an underestimate (e.g. when we have wild cards)
+
+ member_filter, non_member_filter = state_filter.get_member_split()
+ if member_filter.is_full():
+ # We fetched all member events
+ member_types = None
+ else:
+ # `concrete_types()` will only return a subset when there are wild
+ # cards in the filter, but that's fine.
+ member_types = member_filter.concrete_types()
+
+ if non_member_filter.is_full():
+ # We fetched all non member events
+ non_member_types = None
+ else:
+ non_member_types = non_member_filter.concrete_types()
+
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ state_dict_members = {}
+ state_dict_non_members = {}
+
+ for k, v in iteritems(group_state_dict):
+ if k[0] == EventTypes.Member:
+ state_dict_members[k] = v
+ else:
+ state_dict_non_members[k] = v
+
+ self._state_group_members_cache.update(
+ cache_seq_num_members,
+ key=group,
+ value=state_dict_members,
+ fetched_keys=member_types,
+ )
+
+ self._state_group_cache.update(
+ cache_seq_num_non_members,
+ key=group,
+ value=state_dict_non_members,
+ fetched_keys=non_member_types,
+ )
+
+ def store_state_group(
+ self, event_id, room_id, prev_group, delta_ids, current_state_ids
+ ):
+ """Store a new set of state, returning a newly assigned state group.
+
+ Args:
+ event_id (str): The event ID for which the state was calculated
+ room_id (str)
+ prev_group (int|None): A previous state group for the room, optional.
+ delta_ids (dict|None): The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids (dict): The state to store. Map of (type, state_key)
+ to event_id.
+
+ Returns:
+ Deferred[int]: The state group ID
+ """
+
+ def _store_state_group_txn(txn):
+ if current_state_ids is None:
+ # AFAIK, this can never happen
+ raise Exception("current_state_ids cannot be None")
+
+ state_group = self.database_engine.get_next_state_group_id(txn)
+
+ self.db.simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={"id": state_group, "room_id": room_id, "event_id": event_id},
+ )
+
+ # We persist as a delta if we can, while also ensuring the chain
+            # of deltas isn't too long, as otherwise read performance degrades.
+ if prev_group:
+ is_in_db = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+ if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+ self.db.simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={"state_group": state_group, "prev_state_group": prev_group},
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(delta_ids)
+ ],
+ )
+ else:
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(current_state_ids)
+ ],
+ )
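+
+            # To illustrate (hypothetical values): with prev_group=5 and
+            # delta_ids={("m.room.topic", ""): "$event_id"}, the branch above
+            # stores a single state_group_edges row (state_group -> 5) plus one
+            # state_groups_state row for the changed topic; if there is no
+            # usable prev_group (or the delta chain is too long) it instead
+            # stores one state_groups_state row per entry in current_state_ids
+            # and no edge row.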
+
+ # Prefill the state group caches with this group.
+ # It's fine to use the sequence like this as the state group map
+ # is immutable. (If the map wasn't immutable then this prefill could
+ # race with another update)
+
+ current_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] == EventTypes.Member
+ }
+ txn.call_after(
+ self._state_group_members_cache.update,
+ self._state_group_members_cache.sequence,
+ key=state_group,
+ value=dict(current_member_state_ids),
+ )
+
+ current_non_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] != EventTypes.Member
+ }
+ txn.call_after(
+ self._state_group_cache.update,
+ self._state_group_cache.sequence,
+ key=state_group,
+ value=dict(current_non_member_state_ids),
+ )
+
+ return state_group
+
+ return self.db.runInteraction("store_state_group", _store_state_group_txn)
+
+ def purge_unreferenced_state_groups(
+ self, room_id: str, state_groups_to_delete
+ ) -> defer.Deferred:
+ """Deletes no longer referenced state groups and de-deltas any state
+ groups that reference them.
+
+ Args:
+ room_id: The room the state groups belong to (must all be in the
+ same room).
+ state_groups_to_delete (Collection[int]): Set of all state groups
+ to delete.
+ """
+
+ return self.db.runInteraction(
+ "purge_unreferenced_state_groups",
+ self._purge_unreferenced_state_groups,
+ room_id,
+ state_groups_to_delete,
+ )
+
+ def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+ logger.info(
+ "[purge] found %i state groups to delete", len(state_groups_to_delete)
+ )
+
+ rows = self.db.simple_select_many_txn(
+ txn,
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ retcols=("state_group",),
+ )
+
+ remaining_state_groups = {
+ row["state_group"]
+ for row in rows
+ if row["state_group"] not in state_groups_to_delete
+ }
+
+ logger.info(
+ "[purge] de-delta-ing %i remaining state groups",
+ len(remaining_state_groups),
+ )
+
+        # Now we convert the state groups that reference to-be-deleted state
+        # groups into non-delta versions.
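+        # For example (hypothetical IDs): if state group 3 is stored as a
+        # delta on top of group 2 and we are deleting group 2, we first expand
+        # group 3 to its full state, rewrite its state_groups_state rows and
+        # drop its state_group_edges row, and only then delete group 2's rows.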
+ for sg in remaining_state_groups:
+ logger.info("[purge] de-delta-ing remaining state group %s", sg)
+ curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
+ curr_state = curr_state[sg]
+
+ self.db.simple_delete_txn(
+ txn, table="state_groups_state", keyvalues={"state_group": sg}
+ )
+
+ self.db.simple_delete_txn(
+ txn, table="state_group_edges", keyvalues={"state_group": sg}
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": sg,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(curr_state)
+ ],
+ )
+
+ logger.info("[purge] removing redundant state groups")
+ txn.executemany(
+ "DELETE FROM state_groups_state WHERE state_group = ?",
+ ((sg,) for sg in state_groups_to_delete),
+ )
+ txn.executemany(
+ "DELETE FROM state_groups WHERE id = ?",
+ ((sg,) for sg in state_groups_to_delete),
+ )
+
+ @defer.inlineCallbacks
+ def get_previous_state_groups(self, state_groups):
+ """Fetch the previous groups of the given state groups.
+
+ Args:
+ state_groups (Iterable[int])
+
+ Returns:
+ Deferred[dict[int, int]]: mapping from state group to previous
+ state group.
+ """
+
+ rows = yield self.db.simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("prev_state_group", "state_group"),
+ desc="get_previous_state_groups",
+ )
+
+ return {row["state_group"]: row["prev_state_group"] for row in rows}
+
+ def purge_room_state(self, room_id, state_groups_to_delete):
+        """Deletes all records of a room from the state tables
+
+ Args:
+ room_id (str):
+ state_groups_to_delete (list[int]): State groups to delete
+ """
+
+ return self.db.runInteraction(
+ "purge_room_state",
+ self._purge_room_state_txn,
+ room_id,
+ state_groups_to_delete,
+ )
+
+ def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+ # first we have to delete the state groups states
+ logger.info("[purge] removing %s from state_groups_state", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_groups_state",
+ column="state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
+
+ # ... and the state group edges
+ logger.info("[purge] removing %s from state_group_edges", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_group_edges",
+ column="state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
+
+ # ... and the state groups
+ logger.info("[purge] removing %s from state_groups", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_groups",
+ column="id",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
new file mode 100644
index 0000000000..e61595336c
--- /dev/null
+++ b/synapse/storage/database.py
@@ -0,0 +1,1560 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import time
+from time import monotonic as monotonic_time
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
+
+from six import iteritems, iterkeys, itervalues
+from six.moves import intern, range
+
+from prometheus_client import Histogram
+
+from twisted.enterprise import adbapi
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
+from synapse.logging.context import (
+ LoggingContext,
+ LoggingContextOrSentinel,
+ make_deferred_yieldable,
+)
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.background_updates import BackgroundUpdater
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.types import Connection, Cursor
+from synapse.util.stringutils import exception_to_unicode
+
+logger = logging.getLogger(__name__)
+
+# python 3 does not have a maximum int value, so we impose our own bound on
+# transaction IDs
+MAX_TXN_ID = 2 ** 63 - 1
+
+sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
+perf_logger = logging.getLogger("synapse.storage.TIME")
+
+sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
+
+sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
+sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+
+
+# Unique indexes which have been added in background updates. Maps from table name
+# to the name of the background update which added the unique index to that table.
+#
+# This is used by the upsert logic to figure out which tables are safe to do a proper
+# UPSERT on: until the relevant background update has completed, we
+# have to emulate an upsert by locking the table.
+#
+UNIQUE_INDEX_BACKGROUND_UPDATES = {
+ "user_ips": "user_ips_device_unique_index",
+ "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
+ "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
+ "event_search": "event_search_event_id_idx",
+}
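+
+# For example: until the "user_ips_device_unique_index" background update has
+# completed (i.e. while it is still listed in the background_updates table),
+# "user_ips" remains in Database._unsafe_to_upsert_tables and upserts against
+# it fall back to the lock-and-emulate path (see _check_safe_to_upsert and
+# simple_upsert_txn below).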
+
+
+def make_pool(
+ reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+) -> adbapi.ConnectionPool:
+ """Get the connection pool for the database.
+ """
+
+ return adbapi.ConnectionPool(
+ db_config.config["name"],
+ cp_reactor=reactor,
+ cp_openfun=engine.on_new_connection,
+ **db_config.config.get("args", {})
+ )
+
+
+def make_conn(
+ db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+) -> Connection:
+ """Make a new connection to the database and return it.
+
+ Returns:
+ Connection
+ """
+
+ db_params = {
+ k: v
+ for k, v in db_config.config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = engine.module.connect(**db_params)
+ engine.on_new_connection(db_conn)
+ return db_conn
+
+
+# The type of entry which goes on our after_callbacks and exception_callbacks lists.
+#
+# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
+# that mypy sees the type but the runtime python doesn't.
+_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
+
+
+class LoggingTransaction:
+ """An object that almost-transparently proxies for the 'txn' object
+ passed to the constructor. Adds logging and metrics to the .execute()
+ method.
+
+ Args:
+        txn: The database transaction object to wrap.
+        name: The name of this transaction, for logging.
+ database_engine
+ after_callbacks: A list that callbacks will be appended to
+ that have been added by `call_after` which should be run on
+ successful completion of the transaction. None indicates that no
+ callbacks should be allowed to be scheduled to run.
+ exception_callbacks: A list that callbacks will be appended
+ to that have been added by `call_on_exception` which should be run
+ if transaction ends with an error. None indicates that no callbacks
+ should be allowed to be scheduled to run.
+ """
+
+ __slots__ = [
+ "txn",
+ "name",
+ "database_engine",
+ "after_callbacks",
+ "exception_callbacks",
+ ]
+
+ def __init__(
+ self,
+ txn: Cursor,
+ name: str,
+ database_engine: BaseDatabaseEngine,
+ after_callbacks: Optional[List[_CallbackListEntry]] = None,
+ exception_callbacks: Optional[List[_CallbackListEntry]] = None,
+ ):
+ self.txn = txn
+ self.name = name
+ self.database_engine = database_engine
+ self.after_callbacks = after_callbacks
+ self.exception_callbacks = exception_callbacks
+
+ def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+ """Call the given callback on the main twisted thread after the
+ transaction has finished. Used to invalidate the caches on the
+ correct thread.
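+
+        For example (illustrative; `some_cached_method` is hypothetical):
+
+            txn.call_after(some_cached_method.invalidate, (key,))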
+ """
+ # if self.after_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.after_callbacks is not None
+ self.after_callbacks.append((callback, args, kwargs))
+
+ def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+ # if self.exception_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.exception_callbacks is not None
+ self.exception_callbacks.append((callback, args, kwargs))
+
+ def fetchall(self) -> List[Tuple]:
+ return self.txn.fetchall()
+
+ def fetchone(self) -> Tuple:
+ return self.txn.fetchone()
+
+ def __iter__(self) -> Iterator[Tuple]:
+ return self.txn.__iter__()
+
+ @property
+ def rowcount(self) -> int:
+ return self.txn.rowcount
+
+ @property
+ def description(self) -> Any:
+ return self.txn.description
+
+ def execute_batch(self, sql, args):
+ if isinstance(self.database_engine, PostgresEngine):
+ from psycopg2.extras import execute_batch # type: ignore
+
+ self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+ else:
+ for val in args:
+ self.execute(sql, val)
+
+ def execute(self, sql: str, *args: Any):
+ self._do_execute(self.txn.execute, sql, *args)
+
+ def executemany(self, sql: str, *args: Any):
+ self._do_execute(self.txn.executemany, sql, *args)
+
+ def _make_sql_one_line(self, sql):
+        "Strip newlines out of SQL so that logged SQL statements are all on one line"
+ return " ".join(l.strip() for l in sql.splitlines() if l.strip())
+
+ def _do_execute(self, func, sql, *args):
+ sql = self._make_sql_one_line(sql)
+
+ # TODO(paul): Maybe use 'info' and 'debug' for values?
+ sql_logger.debug("[SQL] {%s} %s", self.name, sql)
+
+ sql = self.database_engine.convert_param_style(sql)
+ if args:
+ try:
+ sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
+ except Exception:
+ # Don't let logging failures stop SQL from working
+ pass
+
+ start = time.time()
+
+ try:
+ return func(sql, *args)
+ except Exception as e:
+ logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+ raise
+ finally:
+ secs = time.time() - start
+ sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
+ sql_query_timer.labels(sql.split()[0]).observe(secs)
+
+ def close(self):
+ self.txn.close()
+
+
+class PerformanceCounters(object):
+ def __init__(self):
+ self.current_counters = {}
+ self.previous_counters = {}
+
+ def update(self, key, duration_secs):
+ count, cum_time = self.current_counters.get(key, (0, 0))
+ count += 1
+ cum_time += duration_secs
+ self.current_counters[key] = (count, cum_time)
+
+ def interval(self, interval_duration_secs, limit=3):
+ counters = []
+ for name, (count, cum_time) in iteritems(self.current_counters):
+ prev_count, prev_time = self.previous_counters.get(name, (0, 0))
+ counters.append(
+ (
+ (cum_time - prev_time) / interval_duration_secs,
+ count - prev_count,
+ name,
+ )
+ )
+
+ self.previous_counters = dict(self.current_counters)
+
+ counters.sort(reverse=True)
+
+ top_n_counters = ", ".join(
+ "%s(%d): %.3f%%" % (name, count, 100 * ratio)
+ for ratio, count, name in counters[:limit]
+ )
+
+ return top_n_counters
+
+
+class Database(object):
+ """Wraps a single physical database and connection pool.
+
+ A single database may be used by multiple data stores.
+ """
+
+ _TXN_ID = 0
+
+ def __init__(
+ self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ ):
+ self.hs = hs
+ self._clock = hs.get_clock()
+ self._database_config = database_config
+ self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
+
+ self.updates = BackgroundUpdater(hs, self)
+
+ self._previous_txn_total_time = 0.0
+ self._current_txn_total_time = 0.0
+ self._previous_loop_ts = 0.0
+
+ # TODO(paul): These can eventually be removed once the metrics code
+ # is running in mainline, and we have some nice monitoring frontends
+ # to watch it
+ self._txn_perf_counters = PerformanceCounters()
+
+ self.engine = engine
+
+ # A set of tables that are not safe to use native upserts in.
+ self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+
+ # We add the user_directory_search table to the blacklist on SQLite
+ # because the existing search table does not have an index, making it
+ # unsafe to use native upserts.
+ if isinstance(self.engine, Sqlite3Engine):
+ self._unsafe_to_upsert_tables.add("user_directory_search")
+
+ if self.engine.can_native_upsert:
+ # Check ASAP (and then later, every 1s) to see if we have finished
+ # background updates of tables that aren't safe to update.
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert,
+ )
+
+ def is_running(self):
+ """Is the database pool currently running
+ """
+ return self._db_pool.running
+
+ @defer.inlineCallbacks
+ def _check_safe_to_upsert(self):
+ """
+ Is it safe to use native UPSERT?
+
+ If there are background updates, we will need to wait, as they may be
+ the addition of indexes that set the UNIQUE constraint that we require.
+
+ If the background updates have not completed, wait 15 sec and check again.
+ """
+ updates = yield self.simple_select_list(
+ "background_updates",
+ keyvalues=None,
+ retcols=["update_name"],
+ desc="check_background_updates",
+ )
+ updates = [x["update_name"] for x in updates]
+
+ for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
+ if update_name not in updates:
+ logger.debug("Now safe to upsert in %s", table)
+ self._unsafe_to_upsert_tables.discard(table)
+
+ # If there's any updates still running, reschedule to run.
+ if updates:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert,
+ )
+
+ def start_profiling(self):
+ self._previous_loop_ts = monotonic_time()
+
+ def loop():
+ curr = self._current_txn_total_time
+ prev = self._previous_txn_total_time
+ self._previous_txn_total_time = curr
+
+ time_now = monotonic_time()
+ time_then = self._previous_loop_ts
+ self._previous_loop_ts = time_now
+
+ duration = time_now - time_then
+ ratio = (curr - prev) / duration
+
+ top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
+
+ perf_logger.debug(
+ "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
+ )
+
+ self._clock.looping_call(loop, 10000)
+
+ def new_transaction(
+ self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
+ ):
+ start = monotonic_time()
+ txn_id = self._TXN_ID
+
+        # We don't really need these to be unique, so let's stop it from
+        # growing really large.
+ self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
+
+ name = "%s-%x" % (desc, txn_id)
+
+ transaction_logger.debug("[TXN START] {%s}", name)
+
+ try:
+ i = 0
+ N = 5
+ while True:
+ cursor = LoggingTransaction(
+ conn.cursor(),
+ name,
+ self.engine,
+ after_callbacks,
+ exception_callbacks,
+ )
+ try:
+ r = func(cursor, *args, **kwargs)
+ conn.commit()
+ return r
+ except self.engine.module.OperationalError as e:
+ # This can happen if the database disappears mid
+ # transaction.
+ logger.warning(
+ "[TXN OPERROR] {%s} %s %d/%d",
+ name,
+ exception_to_unicode(e),
+ i,
+ N,
+ )
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.engine.module.Error as e1:
+ logger.warning(
+ "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
+ )
+ continue
+ raise
+ except self.engine.module.DatabaseError as e:
+ if self.engine.is_deadlock(e):
+ logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.engine.module.Error as e1:
+ logger.warning(
+ "[TXN EROLL] {%s} %s",
+ name,
+ exception_to_unicode(e1),
+ )
+ continue
+ raise
+ finally:
+ # we're either about to retry with a new cursor, or we're about to
+ # release the connection. Once we release the connection, it could
+ # get used for another query, which might do a conn.rollback().
+ #
+ # In the latter case, even though that probably wouldn't affect the
+ # results of this transaction, python's sqlite will reset all
+ # statements on the connection [1], which will make our cursor
+ # invalid [2].
+ #
+ # In any case, continuing to read rows after commit()ing seems
+ # dubious from the PoV of ACID transactional semantics
+ # (sqlite explicitly says that once you commit, you may see rows
+ # from subsequent updates.)
+ #
+ # In psycopg2, cursors are essentially a client-side fabrication -
+ # all the data is transferred to the client side when the statement
+ # finishes executing - so in theory we could go on streaming results
+ # from the cursor, but attempting to do so would make us
+ # incompatible with sqlite, so let's make sure we're not doing that
+ # by closing the cursor.
+ #
+ # (*named* cursors in psycopg2 are different and are proper server-
+ # side things, but (a) we don't use them and (b) they are implicitly
+ # closed by ending the transaction anyway.)
+ #
+ # In short, if we haven't finished with the cursor yet, that's a
+ # problem waiting to bite us.
+ #
+ # TL;DR: we're done with the cursor, so we can close it.
+ #
+ # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
+ # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
+ cursor.close()
+ except Exception as e:
+ logger.debug("[TXN FAIL] {%s} %s", name, e)
+ raise
+ finally:
+ end = monotonic_time()
+ duration = end - start
+
+ LoggingContext.current_context().add_database_transaction(duration)
+
+ transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
+
+ self._current_txn_total_time += duration
+ self._txn_perf_counters.update(desc, duration)
+ sql_txn_timer.labels(desc).observe(duration)
+
+ @defer.inlineCallbacks
+ def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
+ """Starts a transaction on the database and runs a given function
+
+ Arguments:
+ desc: description of the transaction, for logging and metrics
+ func: callback function, which will be called with a
+ database transaction (twisted.enterprise.adbapi.Transaction) as
+ its first argument, followed by `args` and `kwargs`.
+
+ args: positional args to pass to `func`
+ kwargs: named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
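+
+        Example (illustrative):
+
+            def _count_updates_txn(txn):
+                txn.execute("SELECT count(*) FROM background_updates")
+                return txn.fetchone()[0]
+
+            count = yield self.runInteraction("count_updates", _count_updates_txn)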
+ """
+ after_callbacks = [] # type: List[_CallbackListEntry]
+ exception_callbacks = [] # type: List[_CallbackListEntry]
+
+ if LoggingContext.current_context() == LoggingContext.sentinel:
+ logger.warning("Starting db txn '%s' from sentinel context", desc)
+
+ try:
+ result = yield self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ **kwargs
+ )
+
+ for after_callback, after_args, after_kwargs in after_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ except: # noqa: E722, as we reraise the exception this is fine.
+ for after_callback, after_args, after_kwargs in exception_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ raise
+
+ return result
+
+ @defer.inlineCallbacks
+ def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+ """Wraps the .runWithConnection() method on the underlying db_pool.
+
+ Arguments:
+ func: callback function, which will be called with a
+ database connection (twisted.enterprise.adbapi.Connection) as
+ its first argument, followed by `args` and `kwargs`.
+ args: positional args to pass to `func`
+ kwargs: named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
+ parent_context = (
+ LoggingContext.current_context()
+ ) # type: Optional[LoggingContextOrSentinel]
+ if parent_context == LoggingContext.sentinel:
+ logger.warning(
+ "Starting db connection from sentinel context: metrics will be lost"
+ )
+ parent_context = None
+
+ start_time = monotonic_time()
+
+ def inner_func(conn, *args, **kwargs):
+ with LoggingContext("runWithConnection", parent_context) as context:
+ sched_duration_sec = monotonic_time() - start_time
+ sql_scheduling_timer.observe(sched_duration_sec)
+ context.add_database_scheduled(sched_duration_sec)
+
+ if self.engine.is_connection_closed(conn):
+ logger.debug("Reconnecting closed database connection")
+ conn.reconnect()
+
+ return func(conn, *args, **kwargs)
+
+ result = yield make_deferred_yieldable(
+ self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+ )
+
+ return result
+
+ @staticmethod
+ def cursor_to_dict(cursor):
+        """Converts a SQL cursor into a list of dicts.
+
+ Args:
+ cursor : The DBAPI cursor which has executed a query.
+ Returns:
+ A list of dicts where the key is the column header.
+ """
+ col_headers = [intern(str(column[0])) for column in cursor.description]
+ results = [dict(zip(col_headers, row)) for row in cursor]
+ return results
+
+ def execute(self, desc, decoder, query, *args):
+ """Runs a single query for a result set.
+
+ Args:
+ decoder - The function which can resolve the cursor results to
+ something meaningful.
+ query - The query string to execute
+ *args - Query args.
+ Returns:
+ The result of decoder(results)
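+
+        Example (illustrative):
+
+            rows = yield self.execute(
+                "get_update_names",
+                self.cursor_to_dict,
+                "SELECT update_name FROM background_updates",
+            )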
+ """
+
+ def interaction(txn):
+ txn.execute(query, args)
+ if decoder:
+ return decoder(txn)
+ else:
+ return txn.fetchall()
+
+ return self.runInteraction(desc, interaction)
+
+ # "Simple" SQL API methods that operate on a single table with no JOINs,
+ # no complex WHERE clauses, just a dict of values for columns.
+
+ @defer.inlineCallbacks
+ def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+ """Executes an INSERT query on the named table.
+
+ Args:
+ table : string giving the table name
+ values : dict of new column names and values for them
+            or_ignore : bool; if True, a conflicting existing row is ignored
+                rather than raising an exception, and the function returns
+                False instead
+ desc : string giving a description of the transaction
+
+ Returns:
+ bool: Whether the row was inserted or not. Only useful when
+ `or_ignore` is True
+ """
+ try:
+ yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+ except self.engine.module.IntegrityError:
+            # We have to handle the or_ignore flag at this layer, since we
+            # can't reuse a cursor after we receive an error from the db.
+ if not or_ignore:
+ raise
+ return False
+ return True
+
+ @staticmethod
+ def simple_insert_txn(txn, table, values):
+ keys, vals = zip(*values.items())
+
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+ table,
+ ", ".join(k for k in keys),
+ ", ".join("?" for _ in keys),
+ )
+
+ txn.execute(sql, vals)
+
+ def simple_insert_many(self, table, values, desc):
+ return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+
+ @staticmethod
+ def simple_insert_many_txn(txn, table, values):
+ if not values:
+ return
+
+ # This is a *slight* abomination to get a list of tuples of key names
+ # and a list of tuples of value names.
+ #
+ # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
+ # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
+ #
+ # The sort is to ensure that we don't rely on dictionary iteration
+ # order.
+ keys, vals = zip(
+ *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+ )
+
+ for k in keys:
+ if k != keys[0]:
+ raise RuntimeError("All items must have the same keys")
+
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+ table,
+ ", ".join(k for k in keys[0]),
+ ", ".join("?" for _ in keys[0]),
+ )
+
+ txn.executemany(sql, vals)
+
+ @defer.inlineCallbacks
+ def simple_upsert(
+ self,
+ table,
+ keyvalues,
+ values,
+ insertion_values={},
+ desc="simple_upsert",
+ lock=True,
+ ):
+ """
+
+ `lock` should generally be set to True (the default), but can be set
+ to False if either of the following are true:
+
+ * there is a UNIQUE INDEX on the key columns. In this case a conflict
+ will cause an IntegrityError in which case this function will retry
+ the update.
+
+ * we somehow know that we are the only thread which will be updating
+ this table.
+
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ Deferred(None or bool): Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
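+
+        Example (illustrative; the table and columns are hypothetical):
+
+            yield self.simple_upsert(
+                table="device_last_seen",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+                values={"last_seen": now_ms},
+                desc="update_device_last_seen",
+            )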
+ """
+ attempts = 0
+ while True:
+ try:
+ result = yield self.runInteraction(
+ desc,
+ self.simple_upsert_txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values,
+ lock=lock,
+ )
+ return result
+ except self.engine.module.IntegrityError as e:
+ attempts += 1
+ if attempts >= 5:
+ # don't retry forever, because things other than races
+ # can cause IntegrityErrors
+ raise
+
+ # presumably we raced with another transaction: let's retry.
+ logger.warning(
+ "IntegrityError when upserting into %s; retrying: %s", table, e
+ )
+
+ def simple_upsert_txn(
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
+ ):
+ """
+ Pick the UPSERT method which works best on the platform. Either the
+ native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+
+ Args:
+ txn: The transaction to use.
+ table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ None or bool: Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+ return self.simple_upsert_txn_native_upsert(
+ txn, table, keyvalues, values, insertion_values=insertion_values
+ )
+ else:
+ return self.simple_upsert_txn_emulated(
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ lock=lock,
+ )
+
+ def simple_upsert_txn_emulated(
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
+ ):
+ """
+ Args:
+ table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ bool: Return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ # We need to lock the table :(, unless we're *really* careful
+ if lock:
+ self.engine.lock_table(txn, table)
+
+ def _getwhere(key):
+ # If the value we're passing in is None (aka NULL), we need to use
+ # IS, not =, as NULL = NULL equals NULL (False).
+ if keyvalues[key] is None:
+ return "%s IS ?" % (key,)
+ else:
+ return "%s = ?" % (key,)
+
+ if not values:
+ # If `values` is empty, then all of the values we care about are in
+ # the unique key, so there is nothing to UPDATE. We can just do a
+ # SELECT instead to see if it exists.
+ sql = "SELECT 1 FROM %s WHERE %s" % (
+ table,
+ " AND ".join(_getwhere(k) for k in keyvalues),
+ )
+ sqlargs = list(keyvalues.values())
+ txn.execute(sql, sqlargs)
+ if txn.fetchall():
+ # We have an existing record.
+ return False
+ else:
+ # First try to update.
+ sql = "UPDATE %s SET %s WHERE %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in values),
+ " AND ".join(_getwhere(k) for k in keyvalues),
+ )
+ sqlargs = list(values.values()) + list(keyvalues.values())
+
+ txn.execute(sql, sqlargs)
+ if txn.rowcount > 0:
+ # successfully updated at least one row.
+ return False
+
+ # We didn't find any existing rows, so insert a new one
+ allvalues = {} # type: Dict[str, Any]
+ allvalues.update(keyvalues)
+ allvalues.update(values)
+ allvalues.update(insertion_values)
+
+ sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues),
+ )
+ txn.execute(sql, list(allvalues.values()))
+ # successfully inserted
+ return True
+
+ def simple_upsert_txn_native_upsert(
+ self, txn, table, keyvalues, values, insertion_values={}
+ ):
+ """
+ Use the native UPSERT functionality in recent PostgreSQL versions.
+
+ Args:
+ table (str): The table to upsert into
+            keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ Returns:
+ None
+ """
+ allvalues = {} # type: Dict[str, Any]
+ allvalues.update(keyvalues)
+ allvalues.update(insertion_values)
+
+ if not values:
+ latter = "NOTHING"
+ else:
+ allvalues.update(values)
+ latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
+
+ sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues),
+ ", ".join(k for k in keyvalues),
+ latter,
+ )
+ txn.execute(sql, list(allvalues.values()))
+
+ def simple_upsert_many_txn(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+ return self.simple_upsert_many_txn_native_upsert(
+ txn, table, key_names, key_values, value_names, value_values
+ )
+ else:
+ return self.simple_upsert_many_txn_emulated(
+ txn, table, key_names, key_values, value_names, value_values
+ )
+
+ def simple_upsert_many_txn_emulated(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times, but without native UPSERT support or batching.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ # No value columns, therefore make a blank list so that the following
+ # zip() works correctly.
+ if not value_names:
+ value_values = [() for x in range(len(key_values))]
+
+ for keyv, valv in zip(key_values, value_values):
+ _keys = {x: y for x, y in zip(key_names, keyv)}
+ _vals = {x: y for x, y in zip(value_names, valv)}
+
+ self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
+
+ def simple_upsert_many_txn_native_upsert(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times, using batching where possible.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ allnames = [] # type: List[str]
+ allnames.extend(key_names)
+ allnames.extend(value_names)
+
+ if not value_names:
+ # No value columns, therefore make a blank list so that the
+ # following zip() works correctly.
+ latter = "NOTHING"
+ value_values = [() for x in range(len(key_values))]
+ else:
+ latter = "UPDATE SET " + ", ".join(
+ k + "=EXCLUDED." + k for k in value_names
+ )
+
+ sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+ table,
+ ", ".join(k for k in allnames),
+ ", ".join("?" for _ in allnames),
+ ", ".join(key_names),
+ latter,
+ )
+
+ args = []
+
+ for x, y in zip(key_values, value_values):
+ args.append(tuple(x) + tuple(y))
+
+ return txn.execute_batch(sql, args)
+
+ def simple_select_one(
+ self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
+ ):
+ """Executes a SELECT query on the named table, which is expected to
+ return a single row, returning multiple columns from it.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ retcols : list of strings giving the names of the columns to return
+
+ allow_none : If true, return None instead of failing if the SELECT
+ statement returns no rows
+ """
+ return self.runInteraction(
+ desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
+ )
+
+ def simple_select_one_onecol(
+ self,
+ table,
+ keyvalues,
+ retcol,
+ allow_none=False,
+ desc="simple_select_one_onecol",
+ ):
+ """Executes a SELECT query on the named table, which is expected to
+ return a single row, returning a single column from it.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ retcol : string giving the name of the column to return
+ """
+ return self.runInteraction(
+ desc,
+ self.simple_select_one_onecol_txn,
+ table,
+ keyvalues,
+ retcol,
+ allow_none=allow_none,
+ )
+
+ @classmethod
+ def simple_select_one_onecol_txn(
+ cls, txn, table, keyvalues, retcol, allow_none=False
+ ):
+ ret = cls.simple_select_onecol_txn(
+ txn, table=table, keyvalues=keyvalues, retcol=retcol
+ )
+
+ if ret:
+ return ret[0]
+ else:
+ if allow_none:
+ return None
+ else:
+ raise StoreError(404, "No row found")
+
+ @staticmethod
+ def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+ sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
+
+ if keyvalues:
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ txn.execute(sql, list(keyvalues.values()))
+ else:
+ txn.execute(sql)
+
+ return [r[0] for r in txn]
+
+ def simple_select_onecol(
+ self, table, keyvalues, retcol, desc="simple_select_onecol"
+ ):
+        """Executes a SELECT query on the named table, which returns a list
+        of the values of the named column from the selected rows.
+
+ Args:
+ table (str): table name
+ keyvalues (dict|None): column names and values to select the rows with
+            retcol (str): column whose value we wish to retrieve.
+
+ Returns:
+ Deferred: Results in a list
+ """
+ return self.runInteraction(
+ desc, self.simple_select_onecol_txn, table, keyvalues, retcol
+ )
+
+ def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ table (str): the table name
+ keyvalues (dict[str, Any] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self.runInteraction(
+ desc, self.simple_select_list_txn, table, keyvalues, retcols
+ )
+
+ @classmethod
+ def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ txn : Transaction object
+ table (str): the table name
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
+ """
+ if keyvalues:
+ sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+ txn.execute(sql, list(keyvalues.values()))
+ else:
+ sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+ txn.execute(sql)
+
+ return cls.cursor_to_dict(txn)
+
+ @defer.inlineCallbacks
+ def simple_select_many_batch(
+ self,
+ table,
+ column,
+ iterable,
+ retcols,
+ keyvalues={},
+ desc="simple_select_many_batch",
+ batch_size=100,
+ ):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+ Args:
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ results = [] # type: List[Dict[str, Any]]
+
+ if not iterable:
+ return results
+
+ # iterables can not be sliced, so convert it to a list first
+ it_list = list(iterable)
+
+ chunks = [
+ it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
+ ]
+ for chunk in chunks:
+ rows = yield self.runInteraction(
+ desc,
+ self.simple_select_many_txn,
+ table,
+ column,
+ chunk,
+ keyvalues,
+ retcols,
+ )
+
+ results.extend(rows)
+
+ return results
+
+ @classmethod
+ def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+ Args:
+ txn : Transaction object
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ if not iterable:
+ return []
+
+ clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+ clauses = [clause]
+
+ for key, value in iteritems(keyvalues):
+ clauses.append("%s = ?" % (key,))
+ values.append(value)
+
+ sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join(clauses),
+ )
+
+ txn.execute(sql, values)
+ return cls.cursor_to_dict(txn)
+
+ def simple_update(self, table, keyvalues, updatevalues, desc):
+ return self.runInteraction(
+ desc, self.simple_update_txn, table, keyvalues, updatevalues
+ )
+
+ @staticmethod
+ def simple_update_txn(txn, table, keyvalues, updatevalues):
+ if keyvalues:
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ else:
+ where = ""
+
+ update_sql = "UPDATE %s SET %s %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in updatevalues),
+ where,
+ )
+
+ txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
+
+ return txn.rowcount
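+
+    # Illustrative example (hypothetical table and column names):
+    #
+    #   simple_update_txn(txn, "profiles", {"user_id": "@alice:hs"},
+    #                     {"displayname": "Alice"})
+    #
+    # issues roughly "UPDATE profiles SET displayname = ? WHERE user_id = ?"
+    # with parameters ["Alice", "@alice:hs"] (update values first, then key
+    # values) and returns the number of rows that matched.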
+
+ def simple_update_one(
+ self, table, keyvalues, updatevalues, desc="simple_update_one"
+ ):
+ """Executes an UPDATE query on the named table, setting new values for
+ columns in a row matching the key values.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ updatevalues : dict giving column names and values to update
+
+        The UPDATE is performed in a single transaction and raises StoreError
+        if no row matched, or if more than one row matched. This can be used
+        to implement compare-and-set by including the expected current value
+        of a column being updated in the 'keyvalues' dict as well.
+ """
+ return self.runInteraction(
+ desc, self.simple_update_one_txn, table, keyvalues, updatevalues
+ )
+
+ @classmethod
+ def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
+
+ if rowcount == 0:
+ raise StoreError(404, "No row found (%s)" % (table,))
+ if rowcount > 1:
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+ @staticmethod
+ def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+ select_sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+
+ txn.execute(select_sql, list(keyvalues.values()))
+ row = txn.fetchone()
+
+ if not row:
+ if allow_none:
+ return None
+ raise StoreError(404, "No row found (%s)" % (table,))
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+ return dict(zip(retcols, row))
+
+ def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+ """Executes a DELETE query on the named table, expecting to delete a
+ single row.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+ return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+
+ @staticmethod
+ def simple_delete_one_txn(txn, table, keyvalues):
+ """Executes a DELETE query on the named table, expecting to delete a
+ single row.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+
+ txn.execute(sql, list(keyvalues.values()))
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found (%s)" % (table,))
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+ def simple_delete(self, table, keyvalues, desc):
+ return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+
+ @staticmethod
+ def simple_delete_txn(txn, table, keyvalues):
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+
+ txn.execute(sql, list(keyvalues.values()))
+ return txn.rowcount
+
+ def simple_delete_many(self, table, column, iterable, keyvalues, desc):
+ return self.runInteraction(
+ desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
+ )
+
+ @staticmethod
+ def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ """Executes a DELETE query on the named table.
+
+        Filters rows by whether the value of `column` is in `iterable`.
+
+ Args:
+ txn : Transaction object
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+
+ Returns:
+            int: Number of rows deleted
+ """
+ if not iterable:
+ return 0
+
+ sql = "DELETE FROM %s" % table
+
+ clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+ clauses = [clause]
+
+ for key, value in iteritems(keyvalues):
+ clauses.append("%s = ?" % (key,))
+ values.append(value)
+
+ if clauses:
+ sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+ txn.execute(sql, values)
+
+ return txn.rowcount
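+
+    # Illustrative example (hypothetical table and column names): deleting
+    # three device rows for one user builds, via make_in_list_sql_clause,
+    # something like
+    #
+    #   DELETE FROM devices WHERE device_id IN (?,?,?) AND user_id = ?
+    #
+    # on SQLite (or "device_id = ANY(?)" on Postgres) and returns the number
+    # of rows deleted.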
+
+ def get_cache_dict(
+ self, db_conn, table, entity_column, stream_column, max_value, limit=100000
+ ):
+ # Fetch a mapping of room_id -> max stream position for "recent" rooms.
+ # It doesn't really matter how many we get, the StreamChangeCache will
+ # do the right thing to ensure it respects the max size of cache.
+ sql = (
+ "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
+ " WHERE %(stream)s > ? - %(limit)s"
+ " GROUP BY %(entity)s"
+ ) % {
+ "table": table,
+ "entity": entity_column,
+ "stream": stream_column,
+ "limit": limit,
+ }
+
+ sql = self.engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, (int(max_value),))
+
+ cache = {row[0]: int(row[1]) for row in txn}
+
+ txn.close()
+
+ if cache:
+ min_val = min(itervalues(cache))
+ else:
+ min_val = max_value
+
+ return cache, min_val
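+
+    # For example, calling this with table="events", entity_column="room_id"
+    # and stream_column="stream_ordering" returns a dict of room_id -> the
+    # highest stream_ordering seen for that room within the last `limit`
+    # stream positions, plus the lowest of those values (or `max_value` if
+    # the dict is empty), which callers use as the earliest position the
+    # resulting StreamChangeCache is valid from.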
+
+ def simple_select_list_paginate(
+ self,
+ table,
+ orderby,
+ start,
+ limit,
+ retcols,
+ filters=None,
+ keyvalues=None,
+ order_direction="ASC",
+ desc="simple_select_list_paginate",
+ ):
+ """
+        Executes a SELECT query on the named table with an OFFSET of `start`
+        and a LIMIT of `limit`, which may return between zero and `limit`
+        rows, returning the result as a list of dicts.
+
+ Args:
+ table (str): the table name
+ filters (dict[str, T] | None):
+ column names and values to filter the rows with, or None to not
+                apply a `column LIKE ?` filter.
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ orderby (str): Column to order the results by.
+ start (int): Index to begin the query at.
+ limit (int): Number of results to return.
+ retcols (iterable[str]): the names of the columns to return
+ order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self.runInteraction(
+ desc,
+ self.simple_select_list_paginate_txn,
+ table,
+ orderby,
+ start,
+ limit,
+ retcols,
+ filters=filters,
+ keyvalues=keyvalues,
+ order_direction=order_direction,
+ )
+
+ @classmethod
+ def simple_select_list_paginate_txn(
+ cls,
+ txn,
+ table,
+ orderby,
+ start,
+ limit,
+ retcols,
+ filters=None,
+ keyvalues=None,
+ order_direction="ASC",
+ ):
+ """
+ Executes a SELECT query on the named table with start and limit,
+ of row numbers, which may return zero or number of rows from start to limit,
+ returning the result as a list of dicts.
+
+ Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
+ select attributes with exact matches. All constraints are joined together
+ using 'AND'.
+
+ Args:
+ txn : Transaction object
+ table (str): the table name
+ orderby (str): Column to order the results by.
+ start (int): Index to begin the query at.
+ limit (int): Number of results to return.
+ retcols (iterable[str]): the names of the columns to return
+ filters (dict[str, T] | None):
+ column names and values to filter the rows with, or None to not
+                apply a `column LIKE ?` filter.
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+        Returns:
+            list[dict[str, Any]]: the paginated rows
+ """
+ if order_direction not in ["ASC", "DESC"]:
+ raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
+
+ where_clause = "WHERE " if filters or keyvalues else ""
+ arg_list = [] # type: List[Any]
+ if filters:
+ where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
+ arg_list += list(filters.values())
+ where_clause += " AND " if filters and keyvalues else ""
+ if keyvalues:
+ where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ arg_list += list(keyvalues.values())
+
+ sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
+ ", ".join(retcols),
+ table,
+ where_clause,
+ orderby,
+ order_direction,
+ )
+ txn.execute(sql, arg_list + [limit, start])
+
+ return cls.cursor_to_dict(txn)
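+
+    # Illustrative example (hypothetical table and column names): with
+    # filters={"name": "%alice%"}, keyvalues={"deactivated": 0} and
+    # orderby="name", this builds roughly
+    #
+    #   SELECT ... FROM users WHERE name LIKE ? AND deactivated = ?
+    #   ORDER BY name ASC LIMIT ? OFFSET ?
+    #
+    # executed with the filter and key values followed by [limit, start].
+    # Note that any LIKE wildcards must already be present in the filter
+    # values.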
+
+ def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ table (str): the table name
+            term (str | None):
+                term to search for, matched against `col` using LIKE
+            col (str): the column to match the search term against
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+            defer.Deferred: resolves to list[dict[str, Any]], or 0 if no
+            search term is given
+ """
+
+ return self.runInteraction(
+ desc, self.simple_search_list_txn, table, term, col, retcols
+ )
+
+ @classmethod
+ def simple_search_list_txn(cls, txn, table, term, col, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ txn : Transaction object
+ table (str): the table name
+            term (str | None):
+                term to search for, matched against `col` using LIKE
+            col (str): the column to match the search term against
+ retcols (iterable[str]): the names of the columns to return
+        Returns:
+            list[dict[str, Any]], or 0 if no search term is given
+ """
+ if term:
+ sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
+ termvalues = ["%%" + term + "%%"]
+ txn.execute(sql, termvalues)
+ else:
+ return 0
+
+ return cls.cursor_to_dict(txn)
+
+
+def make_in_list_sql_clause(
+ database_engine, column: str, iterable: Iterable
+) -> Tuple[str, list]:
+ """Returns an SQL clause that checks the given column is in the iterable.
+
+ On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
+ it expands to `column = ANY(?)`. While both DBs support the `IN` form,
+ using the `ANY` form on postgres means that it views queries with
+ different length iterables as the same, helping the query stats.
+
+ Args:
+ database_engine
+ column: Name of the column
+ iterable: The values to check the column against.
+
+ Returns:
+ A tuple of SQL query and the args
+ """
+
+ if database_engine.supports_using_any_list:
+ # This should hopefully be faster, but also makes postgres query
+ # stats easier to understand.
+ return "%s = ANY(?)" % (column,), [list(iterable)]
+ else:
+ return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
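+
+
+# For example, make_in_list_sql_clause(engine, "event_id", ["$a", "$b"])
+# returns ("event_id IN (?,?)", ["$a", "$b"]) on SQLite, whereas on Postgres
+# it returns ("event_id = ANY(?)", [["$a", "$b"]]), i.e. the whole list is
+# passed as a single bind parameter.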
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
deleted file mode 100644
index 2fabb9e2cb..0000000000
--- a/synapse/storage/end_to_end_keys.py
+++ /dev/null
@@ -1,283 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from six import iteritems
-
-from canonicaljson import encode_canonical_json
-
-from twisted.internet import defer
-
-from synapse.util.caches.descriptors import cached
-
-from ._base import SQLBaseStore, db_to_json
-
-
-class EndToEndKeyWorkerStore(SQLBaseStore):
- @defer.inlineCallbacks
- def get_e2e_device_keys(
- self, query_list, include_all_devices=False, include_deleted_devices=False
- ):
- """Fetch a list of device keys.
- Args:
- query_list(list): List of pairs of user_ids and device_ids.
- include_all_devices (bool): whether to include entries for devices
- that don't have device keys
- include_deleted_devices (bool): whether to include null entries for
- devices which no longer exist (but were in the query_list).
- This option only takes effect if include_all_devices is true.
- Returns:
- Dict mapping from user-id to dict mapping from device_id to
- dict containing "key_json", "device_display_name".
- """
- if not query_list:
- defer.returnValue({})
-
- results = yield self.runInteraction(
- "get_e2e_device_keys",
- self._get_e2e_device_keys_txn,
- query_list,
- include_all_devices,
- include_deleted_devices,
- )
-
- for user_id, device_keys in iteritems(results):
- for device_id, device_info in iteritems(device_keys):
- device_info["keys"] = db_to_json(device_info.pop("key_json"))
-
- defer.returnValue(results)
-
- def _get_e2e_device_keys_txn(
- self, txn, query_list, include_all_devices=False, include_deleted_devices=False
- ):
- query_clauses = []
- query_params = []
-
- if include_all_devices is False:
- include_deleted_devices = False
-
- if include_deleted_devices:
- deleted_devices = set(query_list)
-
- for (user_id, device_id) in query_list:
- query_clause = "user_id = ?"
- query_params.append(user_id)
-
- if device_id is not None:
- query_clause += " AND device_id = ?"
- query_params.append(device_id)
-
- query_clauses.append(query_clause)
-
- sql = (
- "SELECT user_id, device_id, "
- " d.display_name AS device_display_name, "
- " k.key_json"
- " FROM devices d"
- " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
- " WHERE %s"
- ) % (
- "LEFT" if include_all_devices else "INNER",
- " OR ".join("(" + q + ")" for q in query_clauses),
- )
-
- txn.execute(sql, query_params)
- rows = self.cursor_to_dict(txn)
-
- result = {}
- for row in rows:
- if include_deleted_devices:
- deleted_devices.remove((row["user_id"], row["device_id"]))
- result.setdefault(row["user_id"], {})[row["device_id"]] = row
-
- if include_deleted_devices:
- for user_id, device_id in deleted_devices:
- result.setdefault(user_id, {})[device_id] = None
-
- return result
-
- @defer.inlineCallbacks
- def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
- """Retrieve a number of one-time keys for a user
-
- Args:
- user_id(str): id of user to get keys for
- device_id(str): id of device to get keys for
- key_ids(list[str]): list of key ids (excluding algorithm) to
- retrieve
-
- Returns:
- deferred resolving to Dict[(str, str), str]: map from (algorithm,
- key_id) to json string for key
- """
-
- rows = yield self._simple_select_many_batch(
- table="e2e_one_time_keys_json",
- column="key_id",
- iterable=key_ids,
- retcols=("algorithm", "key_id", "key_json"),
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="add_e2e_one_time_keys_check",
- )
-
- defer.returnValue(
- {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
- )
-
- @defer.inlineCallbacks
- def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
- """Insert some new one time keys for a device. Errors if any of the
- keys already exist.
-
- Args:
- user_id(str): id of user to get keys for
- device_id(str): id of device to get keys for
- time_now(long): insertion time to record (ms since epoch)
- new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
- (algorithm, key_id, key json)
- """
-
- def _add_e2e_one_time_keys(txn):
- # We are protected from race between lookup and insertion due to
- # a unique constraint. If there is a race of two calls to
- # `add_e2e_one_time_keys` then they'll conflict and we will only
- # insert one set.
- self._simple_insert_many_txn(
- txn,
- table="e2e_one_time_keys_json",
- values=[
- {
- "user_id": user_id,
- "device_id": device_id,
- "algorithm": algorithm,
- "key_id": key_id,
- "ts_added_ms": time_now,
- "key_json": json_bytes,
- }
- for algorithm, key_id, json_bytes in new_keys
- ],
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- yield self.runInteraction(
- "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
- )
-
- @cached(max_entries=10000)
- def count_e2e_one_time_keys(self, user_id, device_id):
- """ Count the number of one time keys the server has for a device
- Returns:
- Dict mapping from algorithm to number of keys for that algorithm.
- """
-
- def _count_e2e_one_time_keys(txn):
- sql = (
- "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
- " WHERE user_id = ? AND device_id = ?"
- " GROUP BY algorithm"
- )
- txn.execute(sql, (user_id, device_id))
- result = {}
- for algorithm, key_count in txn:
- result[algorithm] = key_count
- return result
-
- return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
-
-
-class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
- def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
- """Stores device keys for a device. Returns whether there was a change
- or the keys were already in the database.
- """
-
- def _set_e2e_device_keys_txn(txn):
- old_key_json = self._simple_select_one_onecol_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- retcol="key_json",
- allow_none=True,
- )
-
- # In py3 we need old_key_json to match new_key_json type. The DB
- # returns unicode while encode_canonical_json returns bytes.
- new_key_json = encode_canonical_json(device_keys).decode("utf-8")
-
- if old_key_json == new_key_json:
- return False
-
- self._simple_upsert_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- values={"ts_added_ms": time_now, "key_json": new_key_json},
- )
-
- return True
-
- return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
-
- def claim_e2e_one_time_keys(self, query_list):
- """Take a list of one time keys out of the database"""
-
- def _claim_e2e_one_time_keys(txn):
- sql = (
- "SELECT key_id, key_json FROM e2e_one_time_keys_json"
- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
- " LIMIT 1"
- )
- result = {}
- delete = []
- for user_id, device_id, algorithm in query_list:
- user_result = result.setdefault(user_id, {})
- device_result = user_result.setdefault(device_id, {})
- txn.execute(sql, (user_id, device_id, algorithm))
- for key_id, key_json in txn:
- device_result[algorithm + ":" + key_id] = key_json
- delete.append((user_id, device_id, algorithm, key_id))
- sql = (
- "DELETE FROM e2e_one_time_keys_json"
- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
- " AND key_id = ?"
- )
- for user_id, device_id, algorithm, key_id in delete:
- txn.execute(sql, (user_id, device_id, algorithm, key_id))
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
- return result
-
- return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys)
-
- def delete_e2e_keys_by_device(self, user_id, device_id):
- def delete_e2e_keys_by_device_txn(txn):
- self._simple_delete_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- )
- self._simple_delete_txn(
- txn,
- table="e2e_one_time_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return self.runInteraction(
- "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
- )
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9d2d519922..035f9ea6e9 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,29 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import importlib
import platform
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine
-SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
-
-def create_engine(database_config):
+def create_engine(database_config) -> BaseDatabaseEngine:
name = database_config["name"]
- engine_class = SUPPORTED_MODULE.get(name, None)
- if engine_class:
+ if name == "sqlite3":
+ import sqlite3
+
+ return Sqlite3Engine(sqlite3, database_config)
+
+ if name == "psycopg2":
# pypy requires psycopg2cffi rather than psycopg2
- if name == "psycopg2" and platform.python_implementation() == "PyPy":
- name = "psycopg2cffi"
- module = importlib.import_module(name)
- return engine_class(module, database_config)
+ if platform.python_implementation() == "PyPy":
+ import psycopg2cffi as psycopg2 # type: ignore
+ else:
+ import psycopg2 # type: ignore
+
+ return PostgresEngine(psycopg2, database_config)
raise RuntimeError("Unsupported database engine '%s'" % (name,))
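+
+
+# For example, create_engine({"name": "psycopg2", ...}) returns a
+# PostgresEngine backed by psycopg2 (or psycopg2cffi on PyPy), while
+# {"name": "sqlite3", ...} returns a Sqlite3Engine; any other name raises
+# RuntimeError.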
-__all__ = ["create_engine", "IncorrectDatabaseSetup"]
+__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ec5a4d198b..ab0bbe4bd3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -12,7 +12,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
+from typing import Generic, TypeVar
+
+from synapse.storage.types import Connection
class IncorrectDatabaseSetup(RuntimeError):
pass
+
+
+ConnectionType = TypeVar("ConnectionType", bound=Connection)
+
+
+class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
+ def __init__(self, module, database_config: dict):
+ self.module = module
+
+ @property
+ @abc.abstractmethod
+ def single_threaded(self) -> bool:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def can_native_upsert(self) -> bool:
+ """
+ Do we support native UPSERTs?
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def supports_tuple_comparison(self) -> bool:
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def supports_using_any_list(self) -> bool:
+ """
+        Do we support using `a = ANY(?)` and passing a list?
+ """
+ ...
+
+ @abc.abstractmethod
+ def check_database(
+ self, db_conn: ConnectionType, allow_outdated_version: bool = False
+ ) -> None:
+ ...
+
+ @abc.abstractmethod
+ def check_new_database(self, txn) -> None:
+ """Gets called when setting up a brand new database. This allows us to
+        apply stricter checks on new databases versus existing databases.
+ """
+ ...
+
+ @abc.abstractmethod
+ def convert_param_style(self, sql: str) -> str:
+ ...
+
+ @abc.abstractmethod
+ def on_new_connection(self, db_conn: ConnectionType) -> None:
+ ...
+
+ @abc.abstractmethod
+ def is_deadlock(self, error: Exception) -> bool:
+ ...
+
+ @abc.abstractmethod
+ def is_connection_closed(self, conn: ConnectionType) -> bool:
+ ...
+
+ @abc.abstractmethod
+ def lock_table(self, txn, table: str) -> None:
+ ...
+
+ @abc.abstractmethod
+ def get_next_state_group_id(self, txn) -> int:
+ """Returns an int that can be used as a new state_group ID
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def server_version(self) -> str:
+ """Gets a string giving the server version. For example: '3.22.0'
+ """
+ ...
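+
+
+# Both PostgresEngine and Sqlite3Engine (updated below) implement this
+# interface; the ConnectionType parameter lets each engine declare the
+# concrete connection class it works with, e.g. Sqlite3Engine subclasses
+# BaseDatabaseEngine["sqlite3.Connection"].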
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 1b97ee74e3..6c7d08a6f2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,38 +13,97 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import IncorrectDatabaseSetup
+import logging
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
-class PostgresEngine(object):
- single_threaded = False
+logger = logging.getLogger(__name__)
+
+class PostgresEngine(BaseDatabaseEngine):
def __init__(self, database_module, database_config):
- self.module = database_module
+ super().__init__(database_module, database_config)
self.module.extensions.register_type(self.module.extensions.UNICODE)
- self.synchronous_commit = database_config.get("synchronous_commit", True)
- self._version = None # unknown as yet
- def check_database(self, txn):
- txn.execute("SHOW SERVER_ENCODING")
- rows = txn.fetchall()
- if rows and rows[0][0] != "UTF8":
- raise IncorrectDatabaseSetup(
- "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
- "See docs/postgres.rst for more information." % (rows[0][0],)
- )
+        # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
+        # actually want to use bytes then wrap it in `bytearray`.
+ def _disable_bytes_adapter(_):
+ raise Exception("Passing bytes to DB is disabled.")
- def convert_param_style(self, sql):
- return sql.replace("?", "%s")
+ self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
+ self.synchronous_commit = database_config.get("synchronous_commit", True)
+ self._version = None # unknown as yet
- def on_new_connection(self, db_conn):
+ @property
+ def single_threaded(self) -> bool:
+ return False
+ def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
+ # Are we on a supported PostgreSQL version?
+ if not allow_outdated_version and self._version < 90500:
+            raise RuntimeError("Synapse requires PostgreSQL 9.5 or above.")
+
+ with db_conn.cursor() as txn:
+ txn.execute("SHOW SERVER_ENCODING")
+ rows = txn.fetchall()
+ if rows and rows[0][0] != "UTF8":
+ raise IncorrectDatabaseSetup(
+ "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
+ "See docs/postgres.md for more information." % (rows[0][0],)
+ )
+
+ txn.execute(
+ "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+ )
+ collation, ctype = txn.fetchone()
+ if collation != "C":
+ logger.warning(
+ "Database has incorrect collation of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information.",
+ collation,
+ )
+
+ if ctype != "C":
+ logger.warning(
+ "Database has incorrect ctype of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information.",
+ ctype,
+ )
+
+ def check_new_database(self, txn):
+ """Gets called when setting up a brand new database. This allows us to
+        apply stricter checks on new databases versus existing databases.
+ """
+
+ txn.execute(
+ "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+ )
+ collation, ctype = txn.fetchone()
+
+ errors = []
+
+ if collation != "C":
+ errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
+
+ if ctype != "C":
+            errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))
+
+ if errors:
+ raise IncorrectDatabaseSetup(
+ "Database is incorrectly configured:\n\n%s\n\n"
+ "See docs/postgres.md for more information." % ("\n".join(errors))
+ )
+
+ def convert_param_style(self, sql):
+ return sql.replace("?", "%s")
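+
+    # For example, "SELECT * FROM users WHERE name = ?" becomes
+    # "SELECT * FROM users WHERE name = %s", matching psycopg2's placeholder
+    # style ("users"/"name" are just illustrative identifiers).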
+
+ def on_new_connection(self, db_conn):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
@@ -64,9 +123,22 @@ class PostgresEngine(object):
@property
def can_native_upsert(self):
"""
- Can we use native UPSERTs? This requires PostgreSQL 9.5+.
+ Can we use native UPSERTs?
+ """
+ return True
+
+ @property
+ def supports_tuple_comparison(self):
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+ """
+ return True
+
+ @property
+ def supports_using_any_list(self):
+        """Do we support using `a = ANY(?)` and passing a list?
"""
- return self._version >= 90500
+ return True
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
@@ -95,8 +167,8 @@ class PostgresEngine(object):
Returns:
string
"""
- # note that this is a bit of a hack because it relies on on_new_connection
- # having been called at least once. Still, that should be a safe bet here.
+ # note that this is a bit of a hack because it relies on check_database
+ # having been called. Still, that should be a safe bet here.
numver = self._version
assert numver is not None
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 933bcf42c2..3bc2e8b986 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,18 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import struct
import threading
+import typing
-from synapse.storage.prepare_database import prepare_database
+from synapse.storage.engines import BaseDatabaseEngine
+if typing.TYPE_CHECKING:
+ import sqlite3 # noqa: F401
-class Sqlite3Engine(object):
- single_threaded = True
+class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def __init__(self, database_module, database_config):
- self.module = database_module
+ super().__init__(database_module, database_config)
+
+ database = database_config.get("args", {}).get("database")
+ self._is_in_memory = database in (None, ":memory:",)
# The current max state_group, or None if we haven't looked
# in the DB yet.
@@ -31,6 +35,10 @@ class Sqlite3Engine(object):
self._current_state_group_id_lock = threading.Lock()
@property
+ def single_threaded(self) -> bool:
+ return True
+
+ @property
def can_native_upsert(self):
"""
Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
@@ -38,14 +46,44 @@ class Sqlite3Engine(object):
"""
return self.module.sqlite_version_info >= (3, 24, 0)
- def check_database(self, txn):
- pass
+ @property
+ def supports_tuple_comparison(self):
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires
+ SQLite 3.15+.
+ """
+ return self.module.sqlite_version_info >= (3, 15, 0)
+
+ @property
+ def supports_using_any_list(self):
+        """Do we support using `a = ANY(?)` and passing a list?
+ """
+ return False
+
+ def check_database(self, db_conn, allow_outdated_version: bool = False):
+ if not allow_outdated_version:
+ version = self.module.sqlite_version_info
+ if version < (3, 11, 0):
+ raise RuntimeError("Synapse requires sqlite 3.11 or above.")
+
+ def check_new_database(self, txn):
+ """Gets called when setting up a brand new database. This allows us to
+        apply stricter checks on new databases versus existing databases.
+ """
def convert_param_style(self, sql):
return sql
def on_new_connection(self, db_conn):
- prepare_database(db_conn, self, config=None)
+ # We need to import here to avoid an import loop.
+ from synapse.storage.prepare_database import prepare_database
+
+ if self._is_in_memory:
+ # In memory databases need to be rebuilt each time. Ideally we'd
+ # reuse the same connection as we do when starting up, but that
+ # would involve using adbapi before we have started the reactor.
+ prepare_database(db_conn, self, config=None)
+
db_conn.create_function("rank", 1, _rank)
def is_deadlock(self, error):
@@ -85,7 +123,7 @@ class Sqlite3Engine(object):
def _parse_match_info(buf):
bufsize = len(buf)
- return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
+ return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
deleted file mode 100644
index 5dc49822b5..0000000000
--- a/synapse/storage/events_worker.py
+++ /dev/null
@@ -1,742 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import division
-
-import itertools
-import logging
-from collections import namedtuple
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes
-from synapse.api.errors import NotFoundError
-from synapse.api.room_versions import EventFormatVersions
-from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
-# these are only included to make the type annotations work
-from synapse.events.snapshot import EventContext # noqa: F401
-from synapse.events.utils import prune_event
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import get_domain_from_id
-from synapse.util.logcontext import (
- LoggingContext,
- PreserveLoggingContext,
- make_deferred_yieldable,
- run_in_background,
-)
-from synapse.util.metrics import Measure
-
-from ._base import SQLBaseStore
-
-logger = logging.getLogger(__name__)
-
-
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
-# control how we batch/bulk fetch events from the database.
-# The values are plucked out of thing air to make initial sync run faster
-# on jki.re
-# TODO: Make these configurable.
-EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
-EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
-EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
-
-
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
-class EventsWorkerStore(SQLBaseStore):
- def get_received_ts(self, event_id):
- """Get received_ts (when it was persisted) for the event.
-
- Raises an exception for unknown events.
-
- Args:
- event_id (str)
-
- Returns:
- Deferred[int|None]: Timestamp in milliseconds, or None for events
- that were persisted before received_ts was implemented.
- """
- return self._simple_select_one_onecol(
- table="events",
- keyvalues={"event_id": event_id},
- retcol="received_ts",
- desc="get_received_ts",
- )
-
- def get_received_ts_by_stream_pos(self, stream_ordering):
- """Given a stream ordering get an approximate timestamp of when it
- happened.
-
- This is done by simply taking the received ts of the first event that
- has a stream ordering greater than or equal to the given stream pos.
- If none exists returns the current time, on the assumption that it must
- have happened recently.
-
- Args:
- stream_ordering (int)
-
- Returns:
- Deferred[int]
- """
-
- def _get_approximate_received_ts_txn(txn):
- sql = """
- SELECT received_ts FROM events
- WHERE stream_ordering >= ?
- LIMIT 1
- """
-
- txn.execute(sql, (stream_ordering,))
- row = txn.fetchone()
- if row and row[0]:
- ts = row[0]
- else:
- ts = self.clock.time_msec()
-
- return ts
-
- return self.runInteraction(
- "get_approximate_received_ts",
- _get_approximate_received_ts_txn,
- )
-
- @defer.inlineCallbacks
- def get_event(
- self,
- event_id,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
- allow_none=False,
- check_room_id=None,
- ):
- """Get an event from the database by event_id.
-
- Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
- False throw a NotFoundError
- check_room_id (str|None): if not None, check the room of the found event.
- If there is a mismatch, behave as per allow_none.
-
- Returns:
- Deferred : A FrozenEvent.
- """
- events = yield self.get_events_as_list(
- [event_id],
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- event = events[0] if events else None
-
- if event is not None and check_room_id is not None:
- if event.room_id != check_room_id:
- event = None
-
- if event is None and not allow_none:
- raise NotFoundError("Could not find event %s" % (event_id,))
-
- defer.returnValue(event)
-
- @defer.inlineCallbacks
- def get_events(
- self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
- ):
- """Get events from the database
-
- Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
-
- Returns:
- Deferred : Dict from event_id to event.
- """
- events = yield self.get_events_as_list(
- event_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- defer.returnValue({e.event_id: e for e in events})
-
- @defer.inlineCallbacks
- def get_events_as_list(
- self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
- ):
- """Get events from the database and return in a list in the same order
- as given by `event_ids` arg.
-
- Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
-
- Returns:
- Deferred[list[EventBase]]: List of events fetched from the database. The
- events are in the same order as `event_ids` arg.
-
- Note that the returned list may be smaller than the list of event
- IDs if not all events could be fetched.
- """
-
- if not event_ids:
- defer.returnValue([])
-
- event_id_list = event_ids
- event_ids = set(event_ids)
-
- event_entry_map = self._get_events_from_cache(
- event_ids, allow_rejected=allow_rejected
- )
-
- missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
- if missing_events_ids:
- log_ctx = LoggingContext.current_context()
- log_ctx.record_event_fetch(len(missing_events_ids))
-
- # Note that _enqueue_events is also responsible for turning db rows
- # into FrozenEvents (via _get_event_from_row), which involves seeing if
- # the events have been redacted, and if so pulling the redaction event out
- # of the database to check it.
- #
- # _enqueue_events is a bit of a rubbish name but naming is hard.
- missing_events = yield self._enqueue_events(
- missing_events_ids, allow_rejected=allow_rejected
- )
-
- event_entry_map.update(missing_events)
-
- events = []
- for event_id in event_id_list:
- entry = event_entry_map.get(event_id, None)
- if not entry:
- continue
-
- # Starting in room version v3, some redactions need to be rechecked if we
- # didn't have the redacted event at the time, so we recheck on read
- # instead.
- if not allow_rejected and entry.event.type == EventTypes.Redaction:
- orig_event_info = yield self._simple_select_one(
- table="events",
- keyvalues={"event_id": entry.event.redacts},
- retcols=["sender", "room_id", "type"],
- allow_none=True,
- )
-
- if not orig_event_info:
- # We don't have the event that is being redacted, so we
- # assume that the event isn't authorized for now. (If we
- # later receive the event, then we will always redact
- # it anyway, since we have this redaction)
- continue
-
- if orig_event_info["room_id"] != entry.event.room_id:
- # Don't process redactions if the redacted event doesn't belong to the
- # redaction's room.
- logger.info("Ignoring redation in another room.")
- continue
-
- if entry.event.internal_metadata.need_to_check_redaction():
- # XXX: we need to avoid calling get_event here.
- #
- # The problem is that we end up at this point when an event
- # which has been redacted is pulled out of the database by
- # _enqueue_events, because _enqueue_events needs to check
- # the redaction before it can cache the redacted event. So
- # obviously, calling get_event to get the redacted event out
- # of the database gives us an infinite loop.
- #
- # For now (quick hack to fix during 0.99 release cycle), we
- # just go and fetch the relevant row from the db, but it
- # would be nice to think about how we can cache this rather
- # than hit the db every time we access a redaction event.
- #
- # One thought on how to do this:
- # 1. split get_events_as_list up so that it is divided into
- # (a) get the rawish event from the db/cache, (b) do the
- # redaction/rejection filtering
- # 2. have _get_event_from_row just call the first half of
- # that
-
- expected_domain = get_domain_from_id(entry.event.sender)
- if (
- get_domain_from_id(orig_event_info["sender"]) == expected_domain
- ):
- # This redaction event is allowed. Mark as not needing a
- # recheck.
- entry.event.internal_metadata.recheck_redaction = False
-
- if allow_rejected or not entry.event.rejected_reason:
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
-
- events.append(event)
-
- if get_prev_content:
- if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
- event.unsigned["replaces_state"],
- get_prev_content=False,
- allow_none=True,
- )
- if prev:
- event.unsigned = dict(event.unsigned)
- event.unsigned["prev_content"] = prev.content
- event.unsigned["prev_sender"] = prev.sender
-
- defer.returnValue(events)
-
- def _invalidate_get_event_cache(self, event_id):
- self._get_event_cache.invalidate((event_id,))
-
- def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
- """Fetch events from the caches
-
- Args:
- events (list(str)): list of event_ids to fetch
- allow_rejected (bool): Whether to teturn events that were rejected
- update_metrics (bool): Whether to update the cache hit ratio metrics
-
- Returns:
- dict of event_id -> _EventCacheEntry for each event_id in cache. If
- allow_rejected is `False` then there will still be an entry but it
- will be `None`
- """
- event_map = {}
-
- for event_id in events:
- ret = self._get_event_cache.get(
- (event_id,), None, update_metrics=update_metrics
- )
- if not ret:
- continue
-
- if allow_rejected or not ret.event.rejected_reason:
- event_map[event_id] = ret
- else:
- event_map[event_id] = None
-
- return event_map
-
- def _do_fetch(self, conn):
- """Takes a database connection and waits for requests for events from
- the _event_fetch_list queue.
- """
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
- self._event_fetch_ongoing -= 1
- return
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
-
- self._fetch_event_list(conn, event_list)
-
- def _fetch_event_list(self, conn, event_list):
- """Handle a load of requests from the _event_fetch_list queue
-
- Args:
- conn (twisted.enterprise.adbapi.Connection): database connection
-
- event_list (list[Tuple[list[str], Deferred]]):
- The fetch requests. Each entry consists of a list of event
- ids to be fetched, and a deferred to be completed once the
- events have been fetched.
-
- """
- with Measure(self._clock, "_fetch_event_list"):
- try:
- event_id_lists = list(zip(*event_list))[0]
- event_ids = [item for sublist in event_id_lists for item in sublist]
-
- rows = self._new_transaction(
- conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
- )
-
- row_dict = {r["event_id"]: r for r in rows}
-
- # We only want to resolve deferreds from the main thread
- def fire(lst, res):
- for ids, d in lst:
- if not d.called:
- try:
- with PreserveLoggingContext():
- d.callback([res[i] for i in ids if i in res])
- except Exception:
- logger.exception("Failed to callback")
-
- with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
- except Exception as e:
- logger.exception("do_fetch")
-
- # We only want to resolve deferreds from the main thread
- def fire(evs, exc):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(exc)
-
- with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, e)
-
- @defer.inlineCallbacks
- def _enqueue_events(self, events, allow_rejected=False):
- """Fetches events from the database using the _event_fetch_list. This
- allows batch and bulk fetching of events - it allows us to fetch events
- without having to create a new transaction for each request for events.
- """
- if not events:
- defer.returnValue({})
-
- events_d = defer.Deferred()
- with self._event_fetch_lock:
- self._event_fetch_list.append((events, events_d))
-
- self._event_fetch_lock.notify()
-
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- should_start = True
- else:
- should_start = False
-
- if should_start:
- run_as_background_process(
- "fetch_events", self.runWithConnection, self._do_fetch
- )
-
- logger.debug("Loading %d events", len(events))
- with PreserveLoggingContext():
- rows = yield events_d
- logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
- if not allow_rejected:
- rows[:] = [r for r in rows if not r["rejects"]]
-
- res = yield make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self._get_event_from_row,
- row["internal_metadata"],
- row["json"],
- row["redacts"],
- rejected_reason=row["rejects"],
- format_version=row["format_version"],
- )
- for row in rows
- ],
- consumeErrors=True,
- )
- )
-
- defer.returnValue({e.event.event_id: e for e in res if e})
-
- def _fetch_event_rows(self, txn, events):
- rows = []
- N = 200
- for i in range(1 + len(events) // N):
- evs = events[i * N : (i + 1) * N]
- if not evs:
- break
-
- sql = (
- "SELECT "
- " e.event_id as event_id, "
- " e.internal_metadata,"
- " e.json,"
- " e.format_version, "
- " r.redacts as redacts,"
- " rej.event_id as rejects "
- " FROM event_json as e"
- " LEFT JOIN rejections as rej USING (event_id)"
- " LEFT JOIN redactions as r ON e.event_id = r.redacts"
- " WHERE e.event_id IN (%s)"
- ) % (",".join(["?"] * len(evs)),)
-
- txn.execute(sql, evs)
- rows.extend(self.cursor_to_dict(txn))
-
- return rows
-
- @defer.inlineCallbacks
- def _get_event_from_row(
- self, internal_metadata, js, redacted, format_version, rejected_reason=None
- ):
- with Measure(self._clock, "_get_event_from_row"):
- d = json.loads(js)
- internal_metadata = json.loads(internal_metadata)
-
- if rejected_reason:
- rejected_reason = yield self._simple_select_one_onecol(
- table="rejections",
- keyvalues={"event_id": rejected_reason},
- retcol="reason",
- desc="_get_event_from_row_rejected_reason",
- )
-
- if format_version is None:
- # This means that we stored the event before we had the concept
- # of a event format version, so it must be a V1 event.
- format_version = EventFormatVersions.V1
-
- original_ev = event_type_from_format_version(format_version)(
- event_dict=d,
- internal_metadata_dict=internal_metadata,
- rejected_reason=rejected_reason,
- )
-
- redacted_event = None
- if redacted and original_ev.type != EventTypes.Redaction:
- redacted_event = prune_event(original_ev)
-
- redaction_id = yield self._simple_select_one_onecol(
- table="redactions",
- keyvalues={"redacts": redacted_event.event_id},
- retcol="event_id",
- desc="_get_event_from_row_redactions",
- )
-
- redacted_event.unsigned["redacted_by"] = redaction_id
- # Get the redaction event.
-
- because = yield self.get_event(
- redaction_id, check_redacted=False, allow_none=True
- )
-
- if because:
- # It's fine to do add the event directly, since get_pdu_json
- # will serialise this field correctly
- redacted_event.unsigned["redacted_because"] = because
-
- # Starting in room version v3, some redactions need to be
- # rechecked if we didn't have the redacted event at the
- # time, so we recheck on read instead.
- if because.internal_metadata.need_to_check_redaction():
- expected_domain = get_domain_from_id(original_ev.sender)
- if get_domain_from_id(because.sender) == expected_domain:
- # This redaction event is allowed. Mark as not needing a
- # recheck.
- because.internal_metadata.recheck_redaction = False
- else:
- # Senders don't match, so the event isn't actually
- # redacted
- redacted_event = None
-
- if because.room_id != original_ev.room_id:
- redacted_event = None
- else:
- # The lack of a redaction likely means that the redaction is invalid
- # and therefore not returned by get_event, so it should be safe to
- # just ignore it here.
- redacted_event = None
-
- cache_entry = _EventCacheEntry(
- event=original_ev, redacted_event=redacted_event
- )
-
- self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
- defer.returnValue(cache_entry)
-
- @defer.inlineCallbacks
- def have_events_in_timeline(self, event_ids):
- """Given a list of event ids, check if we have already processed and
- stored them as non outliers.
- """
- rows = yield self._simple_select_many_batch(
- table="events",
- retcols=("event_id",),
- column="event_id",
- iterable=list(event_ids),
- keyvalues={"outlier": False},
- desc="have_events_in_timeline",
- )
-
- defer.returnValue(set(r["event_id"] for r in rows))
-
- @defer.inlineCallbacks
- def have_seen_events(self, event_ids):
- """Given a list of event ids, check if we have already processed them.
-
- Args:
- event_ids (iterable[str]):
-
- Returns:
- Deferred[set[str]]: The events we have already seen.
- """
- results = set()
-
- def have_seen_events_txn(txn, chunk):
- sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % (
- ",".join("?" * len(chunk)),
- )
- txn.execute(sql, chunk)
- for (event_id,) in txn:
- results.add(event_id)
-
- # break the input up into chunks of 100
- input_iterator = iter(event_ids)
- for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
- yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
- defer.returnValue(results)
-
- def get_seen_events_with_rejections(self, event_ids):
- """Given a list of event ids, check if we rejected them.
-
- Args:
- event_ids (list[str])
-
- Returns:
- Deferred[dict[str, str|None):
- Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps
- to None.
- """
- if not event_ids:
- return defer.succeed({})
-
- def f(txn):
- sql = (
- "SELECT e.event_id, reason FROM events as e "
- "LEFT JOIN rejections as r ON e.event_id = r.event_id "
- "WHERE e.event_id = ?"
- )
-
- res = {}
- for event_id in event_ids:
- txn.execute(sql, (event_id,))
- row = txn.fetchone()
- if row:
- _, rejected = row
- res[event_id] = rejected
-
- return res
-
- return self.runInteraction("get_seen_events_with_rejections", f)
-
- def _get_total_state_event_counts_txn(self, txn, room_id):
- """
- See get_total_state_event_counts.
- """
- # We join against the events table as that has an index on room_id
- sql = """
- SELECT COUNT(*) FROM state_events
- INNER JOIN events USING (room_id, event_id)
- WHERE room_id=?
- """
- txn.execute(sql, (room_id,))
- row = txn.fetchone()
- return row[0] if row else 0
-
- def get_total_state_event_counts(self, room_id):
- """
- Gets the total number of state events in a room.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[int]
- """
- return self.runInteraction(
- "get_total_state_event_counts",
- self._get_total_state_event_counts_txn, room_id
- )
-
- def _get_current_state_event_counts_txn(self, txn, room_id):
- """
- See get_current_state_event_counts.
- """
- sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?"
- txn.execute(sql, (room_id,))
- row = txn.fetchone()
- return row[0] if row else 0
-
- def get_current_state_event_counts(self, room_id):
- """
- Gets the current number of state events in a room.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[int]
- """
- return self.runInteraction(
- "get_current_state_event_counts",
- self._get_current_state_event_counts_txn, room_id
- )
-
- @defer.inlineCallbacks
- def get_room_complexity(self, room_id):
- """
- Get a rough approximation of the complexity of the room. This is used by
- remote servers to decide whether they wish to join the room or not.
- Higher complexity value indicates that being in the room will consume
- more resources.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[dict[str:int]] of complexity version to complexity.
- """
- state_events = yield self.get_current_state_event_counts(room_id)
-
- # Call this one "v1", so we can introduce new ones as we want to develop
- # it.
- complexity_v1 = round(state_events / 500, 2)
-
- defer.returnValue({"v1": complexity_v1})
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index e3655ad8d7..4769b21529 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -14,208 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
-import six
-
import attr
-from signedjson.key import decode_verify_key_bytes
-
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedList
-
-from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
-
@attr.s(slots=True, frozen=True)
class FetchKeyResult(object):
verify_key = attr.ib() # VerifyKey: the key itself
valid_until_ts = attr.ib() # int: how long we can use this key for
-
-
-class KeyStore(SQLBaseStore):
- """Persistence for signature verification keys
- """
-
- @cached()
- def _get_server_verify_key(self, server_name_and_key_id):
- raise NotImplementedError()
-
- @cachedList(
- cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
- )
- def get_server_verify_keys(self, server_name_and_key_ids):
- """
- Args:
- server_name_and_key_ids (iterable[Tuple[str, str]]):
- iterable of (server_name, key-id) tuples to fetch keys for
-
- Returns:
- Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
- map from (server_name, key_id) -> FetchKeyResult, or None if the key is
- unknown
- """
- keys = {}
-
- def _get_keys(txn, batch):
- """Processes a batch of keys to fetch, and adds the result to `keys`."""
-
- # batch_iter always returns tuples so it's safe to do len(batch)
- sql = (
- "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
- "FROM server_signature_keys WHERE 1=0"
- ) + " OR (server_name=? AND key_id=?)" * len(batch)
-
- txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
-
- for row in txn:
- server_name, key_id, key_bytes, ts_valid_until_ms = row
-
- if ts_valid_until_ms is None:
- # Old keys may be stored with a ts_valid_until_ms of null,
- # in which case we treat this as if it was set to `0`, i.e.
- # it won't match key requests that define a minimum
- # `ts_valid_until_ms`.
- ts_valid_until_ms = 0
-
- res = FetchKeyResult(
- verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
- valid_until_ts=ts_valid_until_ms,
- )
- keys[(server_name, key_id)] = res
-
- def _txn(txn):
- for batch in batch_iter(server_name_and_key_ids, 50):
- _get_keys(txn, batch)
- return keys
-
- return self.runInteraction("get_server_verify_keys", _txn)
-
- def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
- """Stores NACL verification keys for remote servers.
- Args:
- from_server (str): Where the verification keys were looked up
- ts_added_ms (int): The time to record that the key was added
- verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
- keys to be stored. Each entry is a triplet of
- (server_name, key_id, key).
- """
- key_values = []
- value_values = []
- invalidations = []
- for server_name, key_id, fetch_result in verify_keys:
- key_values.append((server_name, key_id))
- value_values.append(
- (
- from_server,
- ts_added_ms,
- fetch_result.valid_until_ts,
- db_binary_type(fetch_result.verify_key.encode()),
- )
- )
- # invalidate takes a tuple corresponding to the params of
- # _get_server_verify_key. _get_server_verify_key only takes one
- # param, which is itself the 2-tuple (server_name, key_id).
- invalidations.append((server_name, key_id))
-
- def _invalidate(res):
- f = self._get_server_verify_key.invalidate
- for i in invalidations:
- f((i, ))
- return res
-
- return self.runInteraction(
- "store_server_verify_keys",
- self._simple_upsert_many_txn,
- table="server_signature_keys",
- key_names=("server_name", "key_id"),
- key_values=key_values,
- value_names=(
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "verify_key",
- ),
- value_values=value_values,
- ).addCallback(_invalidate)
-
- def store_server_keys_json(
- self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
- ):
- """Stores the JSON bytes for a set of keys from a server
- The JSON should be signed by the originating server, the intermediate
- server, and by this server. Updates the value for the
- (server_name, key_id, from_server) triplet if one already existed.
- Args:
- server_name (str): The name of the server.
- key_id (str): The identifer of the key this JSON is for.
- from_server (str): The server this JSON was fetched from.
- ts_now_ms (int): The time now in milliseconds.
- ts_valid_until_ms (int): The time when this json stops being valid.
- key_json (bytes): The encoded JSON.
- """
- return self._simple_upsert(
- table="server_keys_json",
- keyvalues={
- "server_name": server_name,
- "key_id": key_id,
- "from_server": from_server,
- },
- values={
- "server_name": server_name,
- "key_id": key_id,
- "from_server": from_server,
- "ts_added_ms": ts_now_ms,
- "ts_valid_until_ms": ts_expires_ms,
- "key_json": db_binary_type(key_json_bytes),
- },
- desc="store_server_keys_json",
- )
-
- def get_server_keys_json(self, server_keys):
- """Retrive the key json for a list of server_keys and key ids.
- If no keys are found for a given server, key_id and source then
- that server, key_id, and source triplet entry will be an empty list.
- The JSON is returned as a byte array so that it can be efficiently
- used in an HTTP response.
- Args:
- server_keys (list): List of (server_name, key_id, source) triplets.
- Returns:
- Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
- Dict mapping (server_name, key_id, source) triplets to lists of dicts
- """
-
- def _get_server_keys_json_txn(txn):
- results = {}
- for server_name, key_id, from_server in server_keys:
- keyvalues = {"server_name": server_name}
- if key_id is not None:
- keyvalues["key_id"] = key_id
- if from_server is not None:
- keyvalues["from_server"] = from_server
- rows = self._simple_select_list_txn(
- txn,
- "server_keys_json",
- keyvalues=keyvalues,
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
- ),
- )
- results[(server_name, key_id, from_server)] = rows
- return results
-
- return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
new file mode 100644
index 0000000000..0f9ac1cf09
--- /dev/null
+++ b/synapse/storage/persist_events.py
@@ -0,0 +1,801 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+from collections import deque, namedtuple
+from typing import Iterable, List, Optional, Set, Tuple
+
+from six import iteritems
+from six.moves import range
+
+import attr
+from prometheus_client import Counter, Histogram
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.state import StateResolutionStore
+from synapse.storage.data_stores import DataStores
+from synapse.types import StateMap
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+# The number of times we are recalculating the current state
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
+
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = Counter(
+ "synapse_storage_events_state_delta_single_event", ""
+)
+
+# The number of times we are recalculating state when we could have reasonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = Counter(
+ "synapse_storage_events_state_delta_reuse_delta", ""
+)
+
+# The number of forward extremities for each new event.
+forward_extremities_counter = Histogram(
+ "synapse_storage_events_forward_extremities_persisted",
+ "Number of forward extremities for each new event",
+ buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+# The number of stale forward extremities for each new event. Stale extremities
+# are those that were in the previous set of extremities as well as the new.
+stale_forward_extremities_counter = Histogram(
+ "synapse_storage_events_stale_forward_extremities_persisted",
+ "Number of unchanged forward extremities for each new event",
+ buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+
+@attr.s(slots=True)
+class DeltaState:
+ """Deltas to use to update the `current_state_events` table.
+
+ Attributes:
+ to_delete: List of type/state_keys to delete from current state
+ to_insert: Map of state to upsert into current state
+        no_longer_in_room: The server is no longer in the room, so the room
+ should e.g. be removed from `current_state_events` table.
+ """
+
+ to_delete = attr.ib(type=List[Tuple[str, str]])
+ to_insert = attr.ib(type=StateMap[str])
+ no_longer_in_room = attr.ib(type=bool, default=False)
+
+
+class _EventPeristenceQueue(object):
+ """Queues up events so that they can be persisted in bulk with only one
+ concurrent transaction per room.
+ """
+
+ _EventPersistQueueItem = namedtuple(
+ "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
+ )
+
+ def __init__(self):
+ self._event_persist_queues = {}
+ self._currently_persisting_rooms = set()
+
+ def add_to_queue(self, room_id, events_and_contexts, backfilled):
+ """Add events to the queue, with the given persist_event options.
+
+ NB: due to the normal usage pattern of this method, it does *not*
+ follow the synapse logcontext rules, and leaves the logcontext in
+ place whether or not the returned deferred is ready.
+
+ Args:
+ room_id (str):
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
+
+ Returns:
+ defer.Deferred: a deferred which will resolve once the events are
+ persisted. Runs its callbacks *without* a logcontext.
+ """
+ queue = self._event_persist_queues.setdefault(room_id, deque())
+ if queue:
+ # if the last item in the queue has the same `backfilled` setting,
+ # we can just add these new events to that item.
+ end_item = queue[-1]
+ if end_item.backfilled == backfilled:
+ end_item.events_and_contexts.extend(events_and_contexts)
+ return end_item.deferred.observe()
+
+ deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+
+ queue.append(
+ self._EventPersistQueueItem(
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ deferred=deferred,
+ )
+ )
+
+ return deferred.observe()
+
+ def handle_queue(self, room_id, per_item_callback):
+ """Attempts to handle the queue for a room if not already being handled.
+
+        The given callback will be invoked for each item in the queue, each of
+        type _EventPersistQueueItem. The per_item_callback will continue to be
+        called with new items until the queue becomes empty. The return value
+        of the callback will be given to the deferreds waiting on the item;
+        exceptions will be passed to the deferreds as well.
+
+ This function should therefore be called whenever anything is added
+ to the queue.
+
+ If another callback is currently handling the queue then it will not be
+ invoked.
+ """
+
+ if room_id in self._currently_persisting_rooms:
+ return
+
+ self._currently_persisting_rooms.add(room_id)
+
+ async def handle_queue_loop():
+ try:
+ queue = self._get_drainining_queue(room_id)
+ for item in queue:
+ try:
+ ret = await per_item_callback(item)
+ except Exception:
+ with PreserveLoggingContext():
+ item.deferred.errback()
+ else:
+ with PreserveLoggingContext():
+ item.deferred.callback(ret)
+ finally:
+ queue = self._event_persist_queues.pop(room_id, None)
+ if queue:
+ self._event_persist_queues[room_id] = queue
+ self._currently_persisting_rooms.discard(room_id)
+
+ # set handle_queue_loop off in the background
+ run_as_background_process("persist_events", handle_queue_loop)
+
+ def _get_drainining_queue(self, room_id):
+ queue = self._event_persist_queues.setdefault(room_id, deque())
+
+ try:
+ while True:
+ yield queue.popleft()
+ except IndexError:
+ # Queue has been drained.
+ pass
+
+
+class EventsPersistenceStorage(object):
+ """High level interface for handling persisting newly received events.
+
+ Takes care of batching up events by room, and calculating the necessary
+ current state and forward extremity changes.
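+
+    A rough usage sketch, e.g. from an ``inlineCallbacks`` function (the
+    surrounding objects here are illustrative):
+
+        persistence = EventsPersistenceStorage(hs, stores)
+        stream_ordering, max_stream_id = yield persistence.persist_event(
+            event, context
+        )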
+ """
+
+ def __init__(self, hs, stores: DataStores):
+ # We ultimately want to split out the state store from the main store,
+ # so we use separate variables here even though they point to the same
+ # store for now.
+ self.main_store = stores.main
+ self.state_store = stores.state
+
+ self._clock = hs.get_clock()
+ self.is_mine_id = hs.is_mine_id
+ self._event_persist_queue = _EventPeristenceQueue()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
+
+ @defer.inlineCallbacks
+ def persist_events(
+ self,
+ events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+ backfilled: bool = False,
+ ):
+ """
+ Write events to the database
+ Args:
+ events_and_contexts: list of tuples of (event, context)
+ backfilled: Whether the results are retrieved from federation
+ via backfill or not. Used to determine if they're "new" events
+ which might update the current state etc.
+
+ Returns:
+ Deferred[int]: the stream ordering of the latest persisted event
+ """
+ partitioned = {}
+ for event, ctx in events_and_contexts:
+ partitioned.setdefault(event.room_id, []).append((event, ctx))
+
+ deferreds = []
+ for room_id, evs_ctxs in iteritems(partitioned):
+ d = self._event_persist_queue.add_to_queue(
+ room_id, evs_ctxs, backfilled=backfilled
+ )
+ deferreds.append(d)
+
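+        # Kick off a background persister for each room; this is a no-op for
+        # rooms that already have one running.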
+ for room_id in partitioned:
+ self._maybe_start_persisting(room_id)
+
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
+
+ max_persisted_id = yield self.main_store.get_current_events_token()
+
+ return max_persisted_id
+
+ @defer.inlineCallbacks
+ def persist_event(
+ self, event: FrozenEvent, context: EventContext, backfilled: bool = False
+ ):
+ """
+ Returns:
+ Deferred[Tuple[int, int]]: the stream ordering of ``event``,
+ and the stream ordering of the latest persisted event
+ """
+ deferred = self._event_persist_queue.add_to_queue(
+ event.room_id, [(event, context)], backfilled=backfilled
+ )
+
+ self._maybe_start_persisting(event.room_id)
+
+ yield make_deferred_yieldable(deferred)
+
+ max_persisted_id = yield self.main_store.get_current_events_token()
+ return (event.internal_metadata.stream_ordering, max_persisted_id)
+
+ def _maybe_start_persisting(self, room_id: str):
+ async def persisting_queue(item):
+ with Measure(self._clock, "persist_events"):
+ await self._persist_events(
+ item.events_and_contexts, backfilled=item.backfilled
+ )
+
+ self._event_persist_queue.handle_queue(room_id, persisting_queue)
+
+ async def _persist_events(
+ self,
+ events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+ backfilled: bool = False,
+ ):
+ """Calculates the change to current state and forward extremities, and
+        persists the given events along with those updates.
+ """
+ if not events_and_contexts:
+ return
+
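+        # Work through the events in chunks of up to 100, so that each pass
+        # over the database handles a bounded number of events.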
+ chunks = [
+ events_and_contexts[x : x + 100]
+ for x in range(0, len(events_and_contexts), 100)
+ ]
+
+ for chunk in chunks:
+ # We can't easily parallelize these since different chunks
+ # might contain the same event. :(
+
+ # NB: Assumes that we are only persisting events for one room
+ # at a time.
+
+ # map room_id->list[event_ids] giving the new forward
+ # extremities in each room
+ new_forward_extremeties = {}
+
+ # map room_id->(type,state_key)->event_id tracking the full
+ # state in each room after adding these events.
+ # This is simply used to prefill the get_current_state_ids
+ # cache
+ current_state_for_room = {}
+
+ # map room_id->(to_delete, to_insert) where to_delete is a list
+ # of type/state keys to remove from current state, and to_insert
+ # is a map (type,key)->event_id giving the state delta in each
+ # room
+ state_delta_for_room = {}
+
+ # Set of remote users which were in rooms the server has left. We
+ # should check if we still share any rooms and if not we mark their
+ # device lists as stale.
+ potentially_left_users = set() # type: Set[str]
+
+ if not backfilled:
+ with Measure(self._clock, "_calculate_state_and_extrem"):
+ # Work out the new "current state" for each room.
+ # We do this by working out what the new extremities are and then
+ # calculating the state from that.
+ events_by_room = {}
+ for event, context in chunk:
+ events_by_room.setdefault(event.room_id, []).append(
+ (event, context)
+ )
+
+ for room_id, ev_ctx_rm in iteritems(events_by_room):
+ latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
+ room_id
+ )
+ new_latest_event_ids = await self._calculate_new_extremities(
+ room_id, ev_ctx_rm, latest_event_ids
+ )
+
+ latest_event_ids = set(latest_event_ids)
+ if new_latest_event_ids == latest_event_ids:
+ # No change in extremities, so no change in state
+ continue
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremeties[room_id] = new_latest_event_ids
+
+ len_1 = (
+ len(latest_event_ids) == 1
+ and len(new_latest_event_ids) == 1
+ )
+ if len_1:
+ all_single_prev_not_state = all(
+ len(event.prev_event_ids()) == 1
+ and not event.is_state()
+ for event, ctx in ev_ctx_rm
+ )
+ # Don't bother calculating state if they're just
+ # a long chain of single ancestor non-state events.
+ if all_single_prev_not_state:
+ continue
+
+ state_delta_counter.inc()
+ if len(new_latest_event_ids) == 1:
+ state_delta_single_event_counter.inc()
+
+ # This is a fairly handwavey check to see if we could
+ # have guessed what the delta would have been when
+ # processing one of these events.
+ # What we're interested in is if the latest extremities
+ # were the same when we created the event as they are
+ # now. When this server creates a new event (as opposed
+ # to receiving it over federation) it will use the
+ # forward extremities as the prev_events, so we can
+ # guess this by looking at the prev_events and checking
+ # if they match the current forward extremities.
+ for ev, _ in ev_ctx_rm:
+ prev_event_ids = set(ev.prev_event_ids())
+ if latest_event_ids == prev_event_ids:
+ state_delta_reuse_delta_counter.inc()
+ break
+
+ logger.debug("Calculating state delta for room %s", room_id)
+ with Measure(
+ self._clock, "persist_events.get_new_state_after_events"
+ ):
+ res = await self._get_new_state_after_events(
+ room_id,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
+ )
+ current_state, delta_ids = res
+
+                        # If either is not None then there has been a change,
+                        # and we need to work out the delta (or use the one
+                        # given).
+ delta = None
+ if delta_ids is not None:
+ # If there is a delta we know that we've
+ # only added or replaced state, never
+ # removed keys entirely.
+ delta = DeltaState([], delta_ids)
+ elif current_state is not None:
+ with Measure(
+ self._clock, "persist_events.calculate_state_delta"
+ ):
+ delta = await self._calculate_state_delta(
+ room_id, current_state
+ )
+
+ if delta:
+                        # If we have a change of state then let's check
+ # whether we're actually still a member of the room,
+ # or if our last user left. If we're no longer in
+ # the room then we delete the current state and
+ # extremities.
+ is_still_joined = await self._is_server_still_joined(
+ room_id,
+ ev_ctx_rm,
+ delta,
+ current_state,
+ potentially_left_users,
+ )
+ if not is_still_joined:
+ logger.info("Server no longer in room %s", room_id)
+ latest_event_ids = []
+ current_state = {}
+ delta.no_longer_in_room = True
+
+ state_delta_for_room[room_id] = delta
+
+                    # If we have the current_state then let's prefill
+ # the cache with it.
+ if current_state is not None:
+ current_state_for_room[room_id] = current_state
+
+ await self.main_store._persist_events_and_state_updates(
+ chunk,
+ current_state_for_room=current_state_for_room,
+ state_delta_for_room=state_delta_for_room,
+ new_forward_extremeties=new_forward_extremeties,
+ backfilled=backfilled,
+ )
+
+ await self._handle_potentially_left_users(potentially_left_users)
+
+ async def _calculate_new_extremities(
+ self,
+ room_id: str,
+ event_contexts: List[Tuple[FrozenEvent, EventContext]],
+ latest_event_ids: List[str],
+ ):
+ """Calculates the new forward extremities for a room given events to
+ persist.
+
+ Assumes that we are only persisting events for one room at a time.
+ """
+
+ # we're only interested in new events which aren't outliers and which aren't
+ # being rejected.
+ new_events = [
+ event
+ for event, ctx in event_contexts
+ if not event.internal_metadata.is_outlier()
+ and not ctx.rejected
+ and not event.internal_metadata.is_soft_failed()
+ ]
+
+ latest_event_ids = set(latest_event_ids)
+
+ # start with the existing forward extremities
+ result = set(latest_event_ids)
+
+ # add all the new events to the list
+ result.update(event.event_id for event in new_events)
+
+ # Now remove all events which are prev_events of any of the new events
+ result.difference_update(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+
+ # Remove any events which are prev_events of any existing events.
+ existing_prevs = await self.main_store._get_events_which_are_prevs(result)
+ result.difference_update(existing_prevs)
+
+ # Finally handle the case where the new events have soft-failed prev
+ # events. If they do we need to remove them and their prev events,
+ # otherwise we end up with dangling extremities.
+ existing_prevs = await self.main_store._get_prevs_before_rejected(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+ result.difference_update(existing_prevs)
+
+ # We only update metrics for events that change forward extremities
+ # (e.g. we ignore backfill/outliers/etc)
+ if result != latest_event_ids:
+ forward_extremities_counter.observe(len(result))
+ stale = latest_event_ids & result
+ stale_forward_extremities_counter.observe(len(stale))
+
+ return result
+
+ async def _get_new_state_after_events(
+ self,
+ room_id: str,
+ events_context: List[Tuple[FrozenEvent, EventContext]],
+ old_latest_event_ids: Iterable[str],
+ new_latest_event_ids: Iterable[str],
+ ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
+ """Calculate the current state dict after adding some new events to
+ a room
+
+ Args:
+ room_id (str):
+ room to which the events are being added. Used for logging etc
+
+ events_context (list[(EventBase, EventContext)]):
+ events and contexts which are being added to the room
+
+ old_latest_event_ids (iterable[str]):
+ the old forward extremities for the room.
+
+ new_latest_event_ids (iterable[str]):
+ the new forward extremities for the room.
+
+ Returns:
+ Returns a tuple of two state maps, the first being the full new current
+ state and the second being the delta to the existing current state.
+ If both are None then there has been no change.
+
+            If there has been a change then we only return the delta if it's
+            already been calculated. Conversely, if we do know the delta then
+ the new current state is only returned if we've already calculated
+ it.
+ """
+ # map from state_group to ((type, key) -> event_id) state map
+ state_groups_map = {}
+
+ # Map from (prev state group, new state group) -> delta state dict
+ state_group_deltas = {}
+
+ for ev, ctx in events_context:
+ if ctx.state_group is None:
+ # This should only happen for outlier events.
+ if not ev.internal_metadata.is_outlier():
+ raise Exception(
+ "Context for new event %s has no state "
+ "group" % (ev.event_id,)
+ )
+ continue
+
+ if ctx.state_group in state_groups_map:
+ continue
+
+ # We're only interested in pulling out state that has already
+ # been cached in the context. We'll pull stuff out of the DB later
+ # if necessary.
+ current_state_ids = ctx.get_cached_current_state_ids()
+ if current_state_ids is not None:
+ state_groups_map[ctx.state_group] = current_state_ids
+
+ if ctx.prev_group:
+ state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+
+ # We need to map the event_ids to their state groups. First, let's
+ # check if the event is one we're persisting, in which case we can
+ # pull the state group from its context.
+ # Otherwise we need to pull the state group from the database.
+
+ # Set of events we need to fetch groups for. (We know none of the old
+ # extremities are going to be in events_context).
+ missing_event_ids = set(old_latest_event_ids)
+
+ event_id_to_state_group = {}
+ for event_id in new_latest_event_ids:
+ # First search in the list of new events we're adding.
+ for ev, ctx in events_context:
+ if event_id == ev.event_id and ctx.state_group is not None:
+ event_id_to_state_group[event_id] = ctx.state_group
+ break
+ else:
+ # If we couldn't find it, then we'll need to pull
+ # the state from the database
+ missing_event_ids.add(event_id)
+
+ if missing_event_ids:
+ # Now pull out the state groups for any missing events from DB
+ event_to_groups = await self.main_store._get_state_group_for_events(
+ missing_event_ids
+ )
+ event_id_to_state_group.update(event_to_groups)
+
+ # State groups of old_latest_event_ids
+ old_state_groups = {
+ event_id_to_state_group[evid] for evid in old_latest_event_ids
+ }
+
+ # State groups of new_latest_event_ids
+ new_state_groups = {
+ event_id_to_state_group[evid] for evid in new_latest_event_ids
+ }
+
+        # If the old and new groups are the same then we don't need to do
+ # anything.
+ if old_state_groups == new_state_groups:
+ return None, None
+
+ if len(new_state_groups) == 1 and len(old_state_groups) == 1:
+            # If we're going from one state group to another, let's check if
+ # we have a delta for that transition. If we do then we can just
+ # return that.
+
+ new_state_group = next(iter(new_state_groups))
+ old_state_group = next(iter(old_state_groups))
+
+ delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
+ if delta_ids is not None:
+ # We have a delta from the existing to new current state,
+                # so let's just return that. If we happen to already have
+                # the current state in memory then let's also return that,
+ # but it doesn't matter if we don't.
+ new_state = state_groups_map.get(new_state_group)
+ return new_state, delta_ids
+
+ # Now that we have calculated new_state_groups we need to get
+ # their state IDs so we can resolve to a single state set.
+ missing_state = new_state_groups - set(state_groups_map)
+ if missing_state:
+ group_to_state = await self.state_store._get_state_for_groups(missing_state)
+ state_groups_map.update(group_to_state)
+
+ if len(new_state_groups) == 1:
+ # If there is only one state group, then we know what the current
+ # state is.
+ return state_groups_map[new_state_groups.pop()], None
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
+
+ state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
+
+ events_map = {ev.event_id: ev for ev, _ in events_context}
+
+ # We need to get the room version, which is in the create event.
+        # Normally that'd be in the database, but it's also possible that we're
+ # currently trying to persist it.
+ room_version = None
+ for ev, _ in events_context:
+ if ev.type == EventTypes.Create and ev.state_key == "":
+ room_version = ev.content.get("room_version", "1")
+ break
+
+ if not room_version:
+ room_version = await self.main_store.get_room_version_id(room_id)
+
+ logger.debug("calling resolve_state_groups from preserve_events")
+ res = await self._state_resolution_handler.resolve_state_groups(
+ room_id,
+ room_version,
+ state_groups,
+ events_map,
+ state_res_store=StateResolutionStore(self.main_store),
+ )
+
+ return res.state, None
+
+ async def _calculate_state_delta(
+ self, room_id: str, current_state: StateMap[str]
+ ) -> DeltaState:
+ """Calculate the new state deltas for a room.
+
+ Assumes that we are only persisting events for one room at a time.
+ """
+ existing_state = await self.main_store.get_current_state_ids(room_id)
+
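+        # Compare the existing current state with the new one: keys that have
+        # disappeared are deleted, and keys that are new or have changed are
+        # upserted.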
+ to_delete = [key for key in existing_state if key not in current_state]
+
+ to_insert = {
+ key: ev_id
+ for key, ev_id in iteritems(current_state)
+ if ev_id != existing_state.get(key)
+ }
+
+ return DeltaState(to_delete=to_delete, to_insert=to_insert)
+
+ async def _is_server_still_joined(
+ self,
+ room_id: str,
+ ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
+ delta: DeltaState,
+ current_state: Optional[StateMap[str]],
+ potentially_left_users: Set[str],
+ ) -> bool:
+ """Check if the server will still be joined after the given events have
+        been persisted.
+
+ Args:
+ room_id
+ ev_ctx_rm
+ delta: The delta of current state between what is in the database
+ and what the new current state will be.
+            current_state: The new current state, if it has already been calculated,
+ otherwise None.
+ potentially_left_users: If the server has left the room, then joined
+ remote users will be added to this set to indicate that the
+ server may no longer be sharing a room with them.
+ """
+
+ if not any(
+ self.is_mine_id(state_key)
+ for typ, state_key in itertools.chain(delta.to_delete, delta.to_insert)
+ if typ == EventTypes.Member
+ ):
+ # There have been no changes to membership of our users, so nothing
+ # has changed and we assume we're still in the room.
+ return True
+
+ # Check if any of the given events are a local join that appear in the
+ # current state
+ events_to_check = [] # Event IDs that aren't an event we're persisting
+ for (typ, state_key), event_id in delta.to_insert.items():
+ if typ != EventTypes.Member or not self.is_mine_id(state_key):
+ continue
+
+ for event, _ in ev_ctx_rm:
+ if event_id == event.event_id:
+ if event.membership == Membership.JOIN:
+ return True
+
+ # The event is not in `ev_ctx_rm`, so we need to pull it out of
+ # the DB.
+ events_to_check.append(event_id)
+
+ # Check if any of the changes that we don't have events for are joins.
+ if events_to_check:
+ rows = await self.main_store.get_membership_from_event_ids(events_to_check)
+ is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
+ if is_still_joined:
+ return True
+
+ # None of the new state events are local joins, so we check the database
+ # to see if there are any other local users in the room. We ignore users
+        # whose state has changed, as we've already checked their new state above.
+ users_to_ignore = [
+ state_key
+ for _, state_key in itertools.chain(delta.to_insert, delta.to_delete)
+ if self.is_mine_id(state_key)
+ ]
+
+ if await self.main_store.is_local_host_in_room_ignoring_users(
+ room_id, users_to_ignore
+ ):
+ return True
+
+ # The server will leave the room, so we go and find out which remote
+ # users will still be joined when we leave.
+ if current_state is None:
+ current_state = await self.main_store.get_current_state_ids(room_id)
+ current_state = dict(current_state)
+ for key in delta.to_delete:
+ current_state.pop(key, None)
+
+ current_state.update(delta.to_insert)
+
+ remote_event_ids = [
+ event_id
+ for (typ, state_key,), event_id in current_state.items()
+ if typ == EventTypes.Member and not self.is_mine_id(state_key)
+ ]
+ rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+ potentially_left_users.update(
+ row["user_id"] for row in rows if row["membership"] == Membership.JOIN
+ )
+
+ return False
+
+ async def _handle_potentially_left_users(self, user_ids: Set[str]):
+ """Given a set of remote users check if the server still shares a room with
+ them. If not then mark those users' device cache as stale.
+ """
+
+ if not user_ids:
+ return
+
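+        # Work out which of the given users we still share a room with; the
+        # remainder no longer share any room with us, so we stop tracking
+        # their device lists.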
+ joined_users = await self.main_store.get_users_server_still_shares_room_with(
+ user_ids
+ )
+ left_users = user_ids - joined_users
+
+ for user_id in left_users:
+ await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index f2c1bed487..6cb7d4b922 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import fnmatch
import imp
import logging
import os
import re
+from collections import Counter
+
+import attr
from synapse.storage.engines.postgres import PostgresEngine
@@ -27,7 +29,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 55
+SCHEMA_VERSION = 57
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -40,7 +42,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass
-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -53,7 +55,10 @@ def prepare_database(db_conn, database_engine, config):
config (synapse.config.homeserver.HomeServerConfig|None):
application config, or None if we are connecting to an existing
database which we expect to be configured already
+        data_stores (list[str]): The names of the data stores that will be used
+ with this database. Defaults to all data stores.
"""
+
try:
cur = db_conn.cursor()
version_info = _get_or_create_schema_state(cur, database_engine)
@@ -65,13 +70,22 @@ def prepare_database(db_conn, database_engine, config):
if user_version != SCHEMA_VERSION:
# If we don't pass in a config file then we are expecting to
# have already upgraded the DB.
- raise UpgradeDatabaseException("Database needs to be upgraded")
+ raise UpgradeDatabaseException(
+ "Expected database schema version %i but got %i"
+ % (SCHEMA_VERSION, user_version)
+ )
else:
_upgrade_existing_database(
- cur, user_version, delta_files, upgraded, database_engine, config
+ cur,
+ user_version,
+ delta_files,
+ upgraded,
+ database_engine,
+ config,
+ data_stores=data_stores,
)
else:
- _setup_new_database(cur, database_engine)
+ _setup_new_database(cur, database_engine, data_stores=data_stores)
# check if any of our configured dynamic modules want a database
if config is not None:
@@ -84,9 +98,10 @@ def prepare_database(db_conn, database_engine, config):
raise
-def _setup_new_database(cur, database_engine):
+def _setup_new_database(cur, database_engine, data_stores):
"""Sets up the database by finding a base set of "full schemas" and then
- applying any necessary deltas.
+ applying any necessary deltas, including schemas from the given data
+ stores.
The "full_schemas" directory has subdirectories named after versions. This
function searches for the highest version less than or equal to
@@ -111,51 +126,83 @@ def _setup_new_database(cur, database_engine):
In the example foo.sql and bar.sql would be run, and then any delta files
for versions strictly greater than 11.
+
+ Note: we apply the full schemas and deltas from the top level `schema/`
+    folder as well as those in the data stores specified.
+
+ Args:
+ cur (Cursor): a database cursor
+ database_engine (DatabaseEngine)
+ data_stores (list[str]): The names of the data stores to instantiate
+ on the given database.
"""
- current_dir = os.path.join(dir_path, "schema", "full_schemas")
- directory_entries = os.listdir(current_dir)
- valid_dirs = []
- pattern = re.compile(r"^\d+(\.sql)?$")
+    # We're about to set up a brand new database, so we check that it's
+ # configured to our liking.
+ database_engine.check_new_database(cur)
- if isinstance(database_engine, PostgresEngine):
- specific = "postgres"
- else:
- specific = "sqlite"
+ current_dir = os.path.join(dir_path, "schema", "full_schemas")
+ directory_entries = os.listdir(current_dir)
- specific_pattern = re.compile(r"^\d+(\.sql." + specific + r")?$")
+ # First we find the highest full schema version we have
+ valid_versions = []
for filename in directory_entries:
- match = pattern.match(filename) or specific_pattern.match(filename)
- abs_path = os.path.join(current_dir, filename)
- if match and os.path.isdir(abs_path):
- ver = int(match.group(0))
- if ver <= SCHEMA_VERSION:
- valid_dirs.append((ver, abs_path))
- else:
- logger.warn("Unexpected entry in 'full_schemas': %s", filename)
+ try:
+ ver = int(filename)
+ except ValueError:
+ continue
- if not valid_dirs:
+ if ver <= SCHEMA_VERSION:
+ valid_versions.append(ver)
+
+ if not valid_versions:
raise PrepareDatabaseException(
"Could not find a suitable base set of full schemas"
)
- max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0])
+ max_current_ver = max(valid_versions)
logger.debug("Initialising schema v%d", max_current_ver)
- directory_entries = os.listdir(sql_dir)
+    # Now let's find all the full schema files, both in the global schema and
+ # in data store schemas.
+ directories = [os.path.join(current_dir, str(max_current_ver))]
+ directories.extend(
+ os.path.join(
+ dir_path,
+ "data_stores",
+ data_store,
+ "schema",
+ "full_schemas",
+ str(max_current_ver),
+ )
+ for data_store in data_stores
+ )
+
+ directory_entries = []
+ for directory in directories:
+ directory_entries.extend(
+ _DirectoryListing(file_name, os.path.join(directory, file_name))
+ for file_name in os.listdir(directory)
+ )
+
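+    # Work out which engine-specific ".sql.<engine>" files apply to this
+    # database.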
+ if isinstance(database_engine, PostgresEngine):
+ specific = "postgres"
+ else:
+ specific = "sqlite"
- for filename in sorted(fnmatch.filter(directory_entries, "*.sql") + fnmatch.filter(
- directory_entries, "*.sql." + specific
- )):
- sql_loc = os.path.join(sql_dir, filename)
- logger.debug("Applying schema %s", sql_loc)
- executescript(cur, sql_loc)
+ directory_entries.sort()
+ for entry in directory_entries:
+ if entry.file_name.endswith(".sql") or entry.file_name.endswith(
+ ".sql." + specific
+ ):
+ logger.debug("Applying schema %s", entry.absolute_path)
+ executescript(cur, entry.absolute_path)
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
),
(max_current_ver, False),
)
@@ -167,6 +214,7 @@ def _setup_new_database(cur, database_engine):
upgraded=False,
database_engine=database_engine,
config=None,
+ data_stores=data_stores,
is_empty=True,
)
@@ -178,6 +226,7 @@ def _upgrade_existing_database(
upgraded,
database_engine,
config,
+ data_stores,
is_empty=False,
):
"""Upgrades an existing database.
@@ -214,6 +263,10 @@ def _upgrade_existing_database(
only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in
some arbitrary order.
+ Note: we apply the delta files from the specified data stores as well as
+ those in the top-level schema. We apply all delta files across data stores
+ for a version before applying those in the next version.
+
Args:
cur (Cursor)
current_version (int): The current version of the schema.
@@ -223,7 +276,19 @@ def _upgrade_existing_database(
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
+ database_engine (DatabaseEngine)
+ config (synapse.config.homeserver.HomeServerConfig|None):
+ None if we are initialising a blank database, otherwise the application
+ config
+ data_stores (list[str]): The names of the data stores to instantiate
+ on the given database.
+ is_empty (bool): Is this a blank database? I.e. do we need to run the
+ upgrade portions of the delta scripts.
"""
+ if is_empty:
+ assert not applied_delta_files
+ else:
+ assert config
if current_version > SCHEMA_VERSION:
raise ValueError(
@@ -231,33 +296,89 @@ def _upgrade_existing_database(
+ "new for the server to understand"
)
+ # some of the deltas assume that config.server_name is set correctly, so now
+ # is a good time to run the sanity check.
+ if not is_empty and "main" in data_stores:
+ from synapse.storage.data_stores.main import check_database_before_upgrade
+
+ check_database_before_upgrade(cur, database_engine, config)
+
start_ver = current_version
if not upgraded:
start_ver += 1
logger.debug("applied_delta_files: %s", applied_delta_files)
+ if isinstance(database_engine, PostgresEngine):
+ specific_engine_extension = ".postgres"
+ else:
+ specific_engine_extension = ".sqlite"
+
+ specific_engine_extensions = (".sqlite", ".postgres")
+
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.info("Upgrading schema to v%d", v)
+ # We need to search both the global and per data store schema
+ # directories for schema updates.
+
+ # First we find the directories to search in
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
+ directories = [delta_dir]
+ for data_store in data_stores:
+ directories.append(
+ os.path.join(
+ dir_path, "data_stores", data_store, "schema", "delta", str(v)
+ )
+ )
- try:
- directory_entries = os.listdir(delta_dir)
- except OSError:
- logger.exception("Could not open delta dir for version %d", v)
- raise UpgradeDatabaseException(
- "Could not open delta dir for version %d" % (v,)
+ # Used to check if we have any duplicate file names
+ file_name_counter = Counter()
+
+ # Now find which directories have anything of interest.
+ directory_entries = []
+ for directory in directories:
+ logger.debug("Looking for schema deltas in %s", directory)
+ try:
+ file_names = os.listdir(directory)
+ directory_entries.extend(
+ _DirectoryListing(file_name, os.path.join(directory, file_name))
+ for file_name in file_names
+ )
+
+ for file_name in file_names:
+ file_name_counter[file_name] += 1
+ except FileNotFoundError:
+ # Data stores can have empty entries for a given version delta.
+ pass
+ except OSError:
+ raise UpgradeDatabaseException(
+ "Could not open delta dir for version %d: %s" % (v, directory)
+ )
+
+ duplicates = {
+ file_name for file_name, count in file_name_counter.items() if count > 1
+ }
+ if duplicates:
+ # We don't support using the same file name in the same delta version.
+ raise PrepareDatabaseException(
+ "Found multiple delta files with the same name in v%d: %s",
+ v,
+ duplicates,
)
+ # We sort to ensure that we apply the delta files in a consistent
+ # order (to avoid bugs caused by inconsistent directory listing order)
directory_entries.sort()
- for file_name in directory_entries:
+ for entry in directory_entries:
+ file_name = entry.file_name
relative_path = os.path.join(str(v), file_name)
- logger.debug("Found file: %s", relative_path)
+ absolute_path = entry.absolute_path
+
+ logger.debug("Found file: %s (%s)", relative_path, absolute_path)
if relative_path in applied_delta_files:
continue
- absolute_path = os.path.join(dir_path, "schema", "delta", relative_path)
root_name, ext = os.path.splitext(file_name)
if ext == ".py":
# This is a python upgrade module. We need to import into some
@@ -273,15 +394,22 @@ def _upgrade_existing_database(
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
- pass
+ continue
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.info("Applying schema %s", relative_path)
executescript(cur, absolute_path)
+ elif ext == specific_engine_extension and root_name.endswith(".sql"):
+ # A .sql file specific to our engine; just read and execute it
+ logger.info("Applying engine-specific schema %s", relative_path)
+ executescript(cur, absolute_path)
+ elif ext in specific_engine_extensions and root_name.endswith(".sql"):
+ # A .sql file for a different engine; skip it.
+ continue
else:
# Not a valid delta file.
- logger.warn(
- "Found directory entry that did not end in .py or" " .sql: %s",
+ logger.warning(
+ "Found directory entry that did not end in .py or .sql: %s",
relative_path,
)
continue
@@ -289,7 +417,7 @@ def _upgrade_existing_database(
# Mark as done.
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO applied_schema_deltas (version, file)" " VALUES (?,?)"
+ "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
),
(v, relative_path),
)
@@ -297,7 +425,7 @@ def _upgrade_existing_database(
cur.execute("DELETE FROM schema_version")
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
),
(v, True),
)
@@ -313,7 +441,7 @@ def _apply_module_schemas(txn, database_engine, config):
application config
"""
for (mod, _config) in config.password_providers:
- if not hasattr(mod, 'get_db_schema_files'):
+ if not hasattr(mod, "get_db_schema_files"):
continue
modname = ".".join((mod.__module__, mod.__name__))
_apply_module_schema_files(
@@ -337,13 +465,13 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
),
(modname,),
)
- applied_deltas = set(d for d, in cur)
+ applied_deltas = {d for d, in cur}
for (name, stream) in names_and_streams:
if name in applied_deltas:
continue
root_name, ext = os.path.splitext(name)
- if ext != '.sql':
+ if ext != ".sql":
raise PrepareDatabaseException(
"only .sql files are currently supported for module schemas"
)
@@ -355,7 +483,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
# Mark as done.
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
+ "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
),
(modname, name),
)
@@ -407,7 +535,7 @@ def get_statements(f):
def executescript(txn, schema_path):
- with open(schema_path, 'r') as f:
+ with open(schema_path, "r") as f:
for statement in get_statements(f):
txn.execute(statement)
@@ -433,3 +561,16 @@ def _get_or_create_schema_state(txn, database_engine):
return current_version, applied_deltas, upgraded
return None
+
+
+@attr.s()
+class _DirectoryListing(object):
+ """Helper class to store schema file name and the
+ absolute path to it.
+
+ These entries get sorted, so for consistency we want to ensure that
+ `file_name` attr is kept first.
+ """
+
+ file_name = attr.ib()
+ absolute_path = attr.ib()
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 42ec8c6bb8..18a462f0ee 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -15,13 +15,7 @@
from collections import namedtuple
-from twisted.internet import defer
-
from synapse.api.constants import PresenceState
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedList
-
-from ._base import SQLBaseStore
class UserPresenceState(
@@ -73,135 +67,3 @@ class UserPresenceState(
status_msg=None,
currently_active=False,
)
-
-
-class PresenceStore(SQLBaseStore):
- @defer.inlineCallbacks
- def update_presence(self, presence_states):
- stream_ordering_manager = self._presence_id_gen.get_next_mult(
- len(presence_states)
- )
-
- with stream_ordering_manager as stream_orderings:
- yield self.runInteraction(
- "update_presence",
- self._update_presence_txn,
- stream_orderings,
- presence_states,
- )
-
- defer.returnValue(
- (stream_orderings[-1], self._presence_id_gen.get_current_token())
- )
-
- def _update_presence_txn(self, txn, stream_orderings, presence_states):
- for stream_id, state in zip(stream_orderings, presence_states):
- txn.call_after(
- self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
- )
- txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
-
- # Actually insert new rows
- self._simple_insert_many_txn(
- txn,
- table="presence_stream",
- values=[
- {
- "stream_id": stream_id,
- "user_id": state.user_id,
- "state": state.state,
- "last_active_ts": state.last_active_ts,
- "last_federation_update_ts": state.last_federation_update_ts,
- "last_user_sync_ts": state.last_user_sync_ts,
- "status_msg": state.status_msg,
- "currently_active": state.currently_active,
- }
- for state in presence_states
- ],
- )
-
- # Delete old rows to stop database from getting really big
- sql = (
- "DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)"
- )
-
- for states in batch_iter(presence_states, 50):
- args = [stream_id]
- args.extend(s.user_id for s in states)
- txn.execute(sql % (",".join("?" for _ in states),), args)
-
- def get_all_presence_updates(self, last_id, current_id):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_presence_updates_txn(txn):
- sql = (
- "SELECT stream_id, user_id, state, last_active_ts,"
- " last_federation_update_ts, last_user_sync_ts, status_msg,"
- " currently_active"
- " FROM presence_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- )
- txn.execute(sql, (last_id, current_id))
- return txn.fetchall()
-
- return self.runInteraction(
- "get_all_presence_updates", get_all_presence_updates_txn
- )
-
- @cached()
- def _get_presence_for_user(self, user_id):
- raise NotImplementedError()
-
- @cachedList(
- cached_method_name="_get_presence_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
- )
- def get_presence_for_users(self, user_ids):
- rows = yield self._simple_select_many_batch(
- table="presence_stream",
- column="user_id",
- iterable=user_ids,
- keyvalues={},
- retcols=(
- "user_id",
- "state",
- "last_active_ts",
- "last_federation_update_ts",
- "last_user_sync_ts",
- "status_msg",
- "currently_active",
- ),
- desc="get_presence_for_users",
- )
-
- for row in rows:
- row["currently_active"] = bool(row["currently_active"])
-
- defer.returnValue({row["user_id"]: UserPresenceState(**row) for row in rows})
-
- def get_current_presence_token(self):
- return self._presence_id_gen.get_current_token()
-
- def allow_presence_visible(self, observed_localpart, observer_userid):
- return self._simple_insert(
- table="presence_allow_inbound",
- values={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="allow_presence_visible",
- or_ignore=True,
- )
-
- def disallow_presence_visible(self, observed_localpart, observer_userid):
- return self._simple_delete_one(
- table="presence_allow_inbound",
- keyvalues={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="disallow_presence_visible",
- )
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
new file mode 100644
index 0000000000..fdc0abf5cf
--- /dev/null
+++ b/synapse/storage/purge_events.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+
+from twisted.internet import defer
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeEventsStorage(object):
+ """High level interface for purging rooms and event history.
+ """
+
+ def __init__(self, hs, stores):
+ self.stores = stores
+
+ @defer.inlineCallbacks
+ def purge_room(self, room_id: str):
+ """Deletes all record of a room
+ """
+
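+        # The main store deletes the room's data and returns the state groups
+        # to delete; the state store then removes those as well.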
+ state_groups_to_delete = yield self.stores.main.purge_room(room_id)
+ yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
+
+ @defer.inlineCallbacks
+ def purge_history(self, room_id, token, delete_local_events):
+ """Deletes room history before a certain point
+
+ Args:
+ room_id (str):
+
+ token (str): A topological token to delete events before
+
+ delete_local_events (bool):
+ if True, we will delete local events as well as remote ones
+ (instead of just marking them as outliers and deleting their
+ state groups).
+ """
+ state_groups = yield self.stores.main.purge_history(
+ room_id, token, delete_local_events
+ )
+
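+        # purge_history returns the state groups that were referenced by the
+        # deleted events; any of those that are now unreferenced can be
+        # removed from the state store as well.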
+ logger.info("[purge] finding state groups that can be deleted")
+
+ sg_to_delete = yield self._find_unreferenced_groups(state_groups)
+
+ yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
+
+ @defer.inlineCallbacks
+ def _find_unreferenced_groups(self, state_groups):
+ """Used when purging history to figure out which state groups can be
+ deleted.
+
+ Args:
+ state_groups (set[int]): Set of state groups referenced by events
+ that are going to be deleted.
+
+ Returns:
+            Deferred[set[int]]: The set of state groups that can be deleted.
+ """
+ # Graph of state group -> previous group
+ graph = {}
+
+        # Set of state groups that we have found to be referenced by events
+ referenced_groups = set()
+
+ # Set of state groups we've already seen
+ state_groups_seen = set(state_groups)
+
+ # Set of state groups to handle next.
+ next_to_search = set(state_groups)
+ while next_to_search:
+            # We bound the number of groups we look up at once, to stop the
+            # SQL query getting too big
+ if len(next_to_search) < 100:
+ current_search = next_to_search
+ next_to_search = set()
+ else:
+ current_search = set(itertools.islice(next_to_search, 100))
+ next_to_search -= current_search
+
+ referenced = yield self.stores.main.get_referenced_state_groups(
+ current_search
+ )
+ referenced_groups |= referenced
+
+ # We don't continue iterating up the state group graphs for state
+ # groups that are referenced.
+ current_search -= referenced
+
+ edges = yield self.stores.state.get_previous_state_groups(current_search)
+
+ prevs = set(edges.values())
+ # We don't bother re-handling groups we've already seen
+ prevs -= state_groups_seen
+ next_to_search |= prevs
+ state_groups_seen |= prevs
+
+ graph.update(edges)
+
+ to_delete = state_groups_seen - referenced_groups
+
+ return to_delete
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 9e406baafa..f47cec0d86 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -14,710 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
-import logging
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.push.baserules import list_with_base_rules
-from synapse.storage.appservice import ApplicationServiceWorkerStore
-from synapse.storage.pusher import PusherWorkerStore
-from synapse.storage.receipts import ReceiptsWorkerStore
-from synapse.storage.roommember import RoomMemberWorkerStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-from ._base import SQLBaseStore
-
-logger = logging.getLogger(__name__)
-
-
-def _load_rules(rawrules, enabled_map):
- ruleslist = []
- for rawrule in rawrules:
- rule = dict(rawrule)
- rule["conditions"] = json.loads(rawrule["conditions"])
- rule["actions"] = json.loads(rawrule["actions"])
- ruleslist.append(rule)
-
- # We're going to be mutating this a lot, so do a deep copy
- rules = list(list_with_base_rules(ruleslist))
-
- for i, rule in enumerate(rules):
- rule_id = rule['rule_id']
- if rule_id in enabled_map:
- if rule.get('enabled', True) != bool(enabled_map[rule_id]):
- # Rules are cached across users.
- rule = dict(rule)
- rule['enabled'] = bool(enabled_map[rule_id])
- rules[i] = rule
-
- return rules
-
-
-class PushRulesWorkerStore(
- ApplicationServiceWorkerStore,
- ReceiptsWorkerStore,
- PusherWorkerStore,
- RoomMemberWorkerStore,
- SQLBaseStore,
-):
- """This is an abstract base class where subclasses must implement
- `get_max_push_rules_stream_id` which can be called in the initializer.
- """
-
- # This ABCMeta metaclass ensures that we cannot be instantiated without
- # the abstract methods being implemented.
- __metaclass__ = abc.ABCMeta
-
- def __init__(self, db_conn, hs):
- super(PushRulesWorkerStore, self).__init__(db_conn, hs)
-
- push_rules_prefill, push_rules_id = self._get_cache_dict(
- db_conn,
- "push_rules_stream",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self.get_max_push_rules_stream_id(),
- )
-
- self.push_rules_stream_cache = StreamChangeCache(
- "PushRulesStreamChangeCache",
- push_rules_id,
- prefilled_cache=push_rules_prefill,
- )
-
- @abc.abstractmethod
- def get_max_push_rules_stream_id(self):
- """Get the position of the push rules stream.
-
- Returns:
- int
- """
- raise NotImplementedError()
-
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_for_user(self, user_id):
- rows = yield self._simple_select_list(
- table="push_rules",
- keyvalues={"user_name": user_id},
- retcols=(
- "user_name",
- "rule_id",
- "priority_class",
- "priority",
- "conditions",
- "actions",
- ),
- desc="get_push_rules_enabled_for_user",
- )
-
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
-
- enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
-
- rules = _load_rules(rows, enabled_map)
-
- defer.returnValue(rules)
-
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_enabled_for_user(self, user_id):
- results = yield self._simple_select_list(
- table="push_rules_enable",
- keyvalues={'user_name': user_id},
- retcols=("user_name", "rule_id", "enabled"),
- desc="get_push_rules_enabled_for_user",
- )
- defer.returnValue(
- {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
- )
-
- def have_push_rules_changed_for_user(self, user_id, last_id):
- if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
- return defer.succeed(False)
- else:
-
- def have_push_rules_changed_txn(txn):
- sql = (
- "SELECT COUNT(stream_id) FROM push_rules_stream"
- " WHERE user_id = ? AND ? < stream_id"
- )
- txn.execute(sql, (user_id, last_id))
- count, = txn.fetchone()
- return bool(count)
-
- return self.runInteraction(
- "have_push_rules_changed", have_push_rules_changed_txn
- )
-
- @cachedList(
- cached_method_name="get_push_rules_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
- )
- def bulk_get_push_rules(self, user_ids):
- if not user_ids:
- defer.returnValue({})
-
- results = {user_id: [] for user_id in user_ids}
-
- rows = yield self._simple_select_many_batch(
- table="push_rules",
- column="user_name",
- iterable=user_ids,
- retcols=("*",),
- desc="bulk_get_push_rules",
- )
-
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
-
- for row in rows:
- results.setdefault(row['user_name'], []).append(row)
-
- enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
-
- for user_id, rules in results.items():
- results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
-
- defer.returnValue(results)
-
- @defer.inlineCallbacks
- def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
- """Move a single push rule from one room to another for a specific user.
-
- Args:
- new_room_id (str): ID of the new room.
- user_id (str): ID of user the push rule belongs to.
- rule (Dict): A push rule.
- """
- # Create new rule id
- rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1])
- new_rule_id = rule_id_scope + "/" + new_room_id
-
- # Change room id in each condition
- for condition in rule.get("conditions", []):
- if condition.get("key") == "room_id":
- condition["pattern"] = new_room_id
-
- # Add the rule for the new room
- yield self.add_push_rule(
- user_id=user_id,
- rule_id=new_rule_id,
- priority_class=rule["priority_class"],
- conditions=rule["conditions"],
- actions=rule["actions"],
- )
-
- # Delete push rule for the old room
- yield self.delete_push_rule(user_id, rule["rule_id"])
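# Illustrative sketch (standalone; the IDs are hypothetical): how the rule ID
# above is re-scoped when a rule moves to another room -- keep everything
# before the last "/" and append the new room ID.
old_rule_id = "global/room/!old:example.org"
new_room_id = "!new:example.org"
rule_id_scope = "/".join(old_rule_id.split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
assert new_rule_id == "global/room/!new:example.org"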
-
- @defer.inlineCallbacks
- def move_push_rules_from_room_to_room_for_user(
- self, old_room_id, new_room_id, user_id
- ):
- """Move all of the push rules from one room to another for a specific
- user.
-
- Args:
- old_room_id (str): ID of the old room.
- new_room_id (str): ID of the new room.
- user_id (str): ID of user to copy push rules for.
- """
- # Retrieve push rules for this user
- user_push_rules = yield self.get_push_rules_for_user(user_id)
-
- # Get rules relating to the old room, move them to the new room, then
- # delete them from the old room
- for rule in user_push_rules:
- conditions = rule.get("conditions", [])
- if any(
- (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
- for c in conditions
- ):
- yield self.move_push_rule_from_room_to_room(new_room_id, user_id, rule)
-
- @defer.inlineCallbacks
- def bulk_get_push_rules_for_room(self, event, context):
- state_group = context.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- current_state_ids = yield context.get_current_state_ids(self)
- result = yield self._bulk_get_push_rules_for_room(
- event.room_id, state_group, current_state_ids, event=event
- )
- defer.returnValue(result)
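# Illustrative sketch (standalone): why a fresh object() works as a
# cache-busting placeholder for an unassigned state_group -- two distinct
# object() instances never compare equal, so "no state group yet" can never
# collide with a cached entry keyed on another placeholder.
sentinel_a = object()
sentinel_b = object()
assert sentinel_a != sentinel_b
assert sentinel_a == sentinel_a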
-
- @cachedInlineCallbacks(num_args=2, cache_context=True)
- def _bulk_get_push_rules_for_room(
- self, room_id, state_group, current_state_ids, cache_context, event=None
- ):
- # We don't use `state_group`, it's there so that we can cache based
- # on it. However, it's important that it's never None, since two current_states
- # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
-
- # We will also want to generate notifs for other people in the room so
- # their unread counts are correct in the event stream, but to avoid
- # generating them for bot / AS users etc, we only do so for people who've
- # sent a read receipt into the room.
-
- users_in_room = yield self._get_joined_users_from_context(
- room_id,
- state_group,
- current_state_ids,
- on_invalidate=cache_context.invalidate,
- event=event,
- )
-
- # We ignore app service users for now. This is so that we don't fill
- # up the `get_if_users_have_pushers` cache with AS entries that we
- # know don't have pushers, nor even read receipts.
- local_users_in_room = set(
- u
- for u in users_in_room
- if self.hs.is_mine_id(u)
- and not self.get_if_app_services_interested_in_user(u)
- )
-
- # users in the room who have pushers need to get push rules run because
- # that's how their pushers work
- if_users_with_pushers = yield self.get_if_users_have_pushers(
- local_users_in_room, on_invalidate=cache_context.invalidate
- )
- user_ids = set(
- uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
- )
-
- users_with_receipts = yield self.get_users_with_read_receipts_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
-
- # users who have sent read receipts also get their push rules run, but
- # only if they're local to us
- for uid in users_with_receipts:
- if uid in local_users_in_room:
- user_ids.add(uid)
-
- rules_by_user = yield self.bulk_get_push_rules(
- user_ids, on_invalidate=cache_context.invalidate
- )
-
- rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
-
- defer.returnValue(rules_by_user)
-
- @cachedList(
- cached_method_name="get_push_rules_enabled_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
- )
- def bulk_get_push_rules_enabled(self, user_ids):
- if not user_ids:
- defer.returnValue({})
-
- results = {user_id: {} for user_id in user_ids}
-
- rows = yield self._simple_select_many_batch(
- table="push_rules_enable",
- column="user_name",
- iterable=user_ids,
- retcols=("user_name", "rule_id", "enabled"),
- desc="bulk_get_push_rules_enabled",
- )
- for row in rows:
- enabled = bool(row['enabled'])
- results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
- defer.returnValue(results)
-
-
-class PushRuleStore(PushRulesWorkerStore):
- @defer.inlineCallbacks
- def add_push_rule(
- self,
- user_id,
- rule_id,
- priority_class,
- conditions,
- actions,
- before=None,
- after=None,
- ):
- conditions_json = json.dumps(conditions)
- actions_json = json.dumps(actions)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- if before or after:
- yield self.runInteraction(
- "_add_push_rule_relative_txn",
- self._add_push_rule_relative_txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- before,
- after,
- )
- else:
- yield self.runInteraction(
- "_add_push_rule_highest_priority_txn",
- self._add_push_rule_highest_priority_txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- )
-
- def _add_push_rule_relative_txn(
- self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- before,
- after,
- ):
- # Lock the table since otherwise we'll have annoying races between the
- # SELECT here and the UPSERT below.
- self.database_engine.lock_table(txn, "push_rules")
-
- relative_to_rule = before or after
-
- res = self._simple_select_one_txn(
- txn,
- table="push_rules",
- keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
- retcols=["priority_class", "priority"],
- allow_none=True,
- )
-
- if not res:
- raise RuleNotFoundException(
- "before/after rule not found: %s" % (relative_to_rule,)
- )
-
- base_priority_class = res["priority_class"]
- base_rule_priority = res["priority"]
-
- if base_priority_class != priority_class:
- raise InconsistentRuleException(
- "Given priority class does not match class of relative rule"
- )
-
- if before:
- # Higher priority rules are executed first, so adding a rule before
- # a rule means giving it a higher priority than that rule.
- new_rule_priority = base_rule_priority + 1
- else:
- # We increment the priority of the existing rules to make space for
- # the new rule. Therefore if we want this rule to appear after
- # an existing rule we give it the priority of the existing rule,
- # and then increment the priority of the existing rule.
- new_rule_priority = base_rule_priority
-
- sql = (
- "UPDATE push_rules SET priority = priority + 1"
- " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
- )
-
- txn.execute(sql, (user_id, priority_class, new_rule_priority))
-
- self._upsert_push_rule_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- new_rule_priority,
- conditions_json,
- actions_json,
- )
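# Illustrative sketch (standalone model of the arithmetic above, not the real
# storage code): higher priority runs first, so inserting "before" a rule means
# taking base + 1, inserting "after" means taking the base priority, and any
# existing rule at or above the new priority is bumped up by one.
def insert_relative(rules, new_rule_id, relative_to, before):
    # rules: rule_id -> priority, all within one priority_class
    base = rules[relative_to]
    new_priority = base + 1 if before else base
    for rule_id, priority in list(rules.items()):
        if priority >= new_priority:
            rules[rule_id] = priority + 1  # mirrors the UPDATE ... priority + 1
    rules[new_rule_id] = new_priority      # mirrors the upsert of the new rule

rules = {"x": 0, "y": 1}
insert_relative(rules, "z", relative_to="y", before=True)
assert rules == {"x": 0, "y": 1, "z": 2}   # z outranks y, so it runs before it

rules = {"x": 0, "y": 1}
insert_relative(rules, "z", relative_to="x", before=False)
assert rules == {"x": 1, "y": 2, "z": 0}   # z now runs directly after x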
-
- def _add_push_rule_highest_priority_txn(
- self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- ):
- # Lock the table since otherwise we'll have annoying races between the
- # SELECT here and the UPSERT below.
- self.database_engine.lock_table(txn, "push_rules")
-
- # find the highest priority rule in that class
- sql = (
- "SELECT COUNT(*), MAX(priority) FROM push_rules"
- " WHERE user_name = ? and priority_class = ?"
- )
- txn.execute(sql, (user_id, priority_class))
- res = txn.fetchall()
- (how_many, highest_prio) = res[0]
-
- new_prio = 0
- if how_many > 0:
- new_prio = highest_prio + 1
-
- self._upsert_push_rule_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- new_prio,
- conditions_json,
- actions_json,
- )
-
- def _upsert_push_rule_txn(
- self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- priority,
- conditions_json,
- actions_json,
- update_stream=True,
- ):
- """Specialised version of _simple_upsert_txn that picks a push_rule_id
- using the _push_rule_id_gen if it needs to insert the rule. It assumes
- that the "push_rules" table is locked"""
-
- sql = (
- "UPDATE push_rules"
- " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
- " WHERE user_name = ? AND rule_id = ?"
- )
-
- txn.execute(
- sql,
- (priority_class, priority, conditions_json, actions_json, user_id, rule_id),
- )
-
- if txn.rowcount == 0:
- # We didn't update a row with the given rule_id so insert one
- push_rule_id = self._push_rule_id_gen.get_next()
-
- self._simple_insert_txn(
- txn,
- table="push_rules",
- values={
- "id": push_rule_id,
- "user_name": user_id,
- "rule_id": rule_id,
- "priority_class": priority_class,
- "priority": priority,
- "conditions": conditions_json,
- "actions": actions_json,
- },
- )
-
- if update_stream:
- self._insert_push_rules_update_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- op="ADD",
- data={
- "priority_class": priority_class,
- "priority": priority,
- "conditions": conditions_json,
- "actions": actions_json,
- },
- )
-
- @defer.inlineCallbacks
- def delete_push_rule(self, user_id, rule_id):
- """
- Delete a push rule. The args specify the row to be deleted; they can be
- any of the columns in the push_rules table, but the standard ones are
- listed below.
-
- Args:
- user_id (str): The matrix ID of the push rule owner
- rule_id (str): The rule_id of the rule to be deleted
- """
-
- def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
- self._simple_delete_one_txn(
- txn, "push_rules", {'user_name': user_id, 'rule_id': rule_id}
- )
-
- self._insert_push_rules_update_txn(
- txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
- )
-
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.runInteraction(
- "delete_push_rule",
- delete_push_rule_txn,
- stream_id,
- event_stream_ordering,
- )
-
- @defer.inlineCallbacks
- def set_push_rule_enabled(self, user_id, rule_id, enabled):
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.runInteraction(
- "_set_push_rule_enabled_txn",
- self._set_push_rule_enabled_txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- enabled,
- )
-
- def _set_push_rule_enabled_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
- ):
- new_id = self._push_rules_enable_id_gen.get_next()
- self._simple_upsert_txn(
- txn,
- "push_rules_enable",
- {'user_name': user_id, 'rule_id': rule_id},
- {'enabled': 1 if enabled else 0},
- {'id': new_id},
- )
-
- self._insert_push_rules_update_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- op="ENABLE" if enabled else "DISABLE",
- )
-
- @defer.inlineCallbacks
- def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
- actions_json = json.dumps(actions)
-
- def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
- if is_default_rule:
- # Add a dummy rule to the rules table with the user specified
- # actions.
- priority_class = -1
- priority = 1
- self._upsert_push_rule_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- priority,
- "[]",
- actions_json,
- update_stream=False,
- )
- else:
- self._simple_update_one_txn(
- txn,
- "push_rules",
- {'user_name': user_id, 'rule_id': rule_id},
- {'actions': actions_json},
- )
-
- self._insert_push_rules_update_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- op="ACTIONS",
- data={"actions": actions_json},
- )
-
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.runInteraction(
- "set_push_rule_actions",
- set_push_rule_actions_txn,
- stream_id,
- event_stream_ordering,
- )
-
- def _insert_push_rules_update_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
- ):
- values = {
- "stream_id": stream_id,
- "event_stream_ordering": event_stream_ordering,
- "user_id": user_id,
- "rule_id": rule_id,
- "op": op,
- }
- if data is not None:
- values.update(data)
-
- self._simple_insert_txn(txn, "push_rules_stream", values=values)
-
- txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
- txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
- txn.call_after(
- self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
- )
-
- def get_all_push_rule_updates(self, last_id, current_id, limit):
- """Get all the push rules changes that have happend on the server"""
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_push_rule_updates_txn(txn):
- sql = (
- "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
- " op, priority_class, priority, conditions, actions"
- " FROM push_rules_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
-
- return self.runInteraction(
- "get_all_push_rule_updates", get_all_push_rule_updates_txn
- )
-
- def get_push_rules_stream_token(self):
- """Get the position of the push rules stream.
- Returns a pair of a stream id for the push_rules stream and the
- room stream ordering it corresponds to."""
- return self._push_rules_stream_id_gen.get_current_token()
-
- def get_max_push_rules_stream_id(self):
- return self.get_push_rules_stream_token()[0]
-
class RuleNotFoundException(Exception):
pass
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 4c83800cca..d471ec9860 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -17,13 +17,7 @@ import logging
import attr
-from twisted.internet import defer
-
-from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.stream import generate_pagination_where_clause
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -60,7 +54,7 @@ class PaginationChunk(object):
class RelationPaginationToken(object):
"""Pagination token for relation pagination API.
- As the results are order by topological ordering, we can use the
+ As the results are in topological order, we can use the
`topological_ordering` and `stream_ordering` fields of the events at the
boundaries of the chunk as pagination tokens.
@@ -115,362 +109,3 @@ class AggregationPaginationToken(object):
def as_tuple(self):
return attr.astuple(self)
-
-
-class RelationsWorkerStore(SQLBaseStore):
- @cached(tree=True)
- def get_relations_for_event(
- self,
- event_id,
- relation_type=None,
- event_type=None,
- aggregation_key=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
- """Get a list of relations for an event, ordered by topological ordering.
-
- Args:
- event_id (str): Fetch events that relate to this event ID.
- relation_type (str|None): Only fetch events with this relation
- type, if given.
- event_type (str|None): Only fetch events with this event type, if
- given.
- aggregation_key (str|None): Only fetch events with this aggregation
- key, if given.
- limit (int): Only fetch the most recent `limit` events.
- direction (str): Whether to fetch the most recent first (`"b"`) or
- the oldest first (`"f"`).
- from_token (RelationPaginationToken|None): Fetch rows from the given
- token, or from the start if None.
- to_token (RelationPaginationToken|None): Fetch rows up to the given
- token, or up to the end if None.
-
- Returns:
- Deferred[PaginationChunk]: List of event IDs that match relations
- requested. The rows are of the form `{"event_id": "..."}`.
- """
-
- where_clause = ["relates_to_id = ?"]
- where_args = [event_id]
-
- if relation_type is not None:
- where_clause.append("relation_type = ?")
- where_args.append(relation_type)
-
- if event_type is not None:
- where_clause.append("type = ?")
- where_args.append(event_type)
-
- if aggregation_key:
- where_clause.append("aggregation_key = ?")
- where_args.append(aggregation_key)
-
- pagination_clause = generate_pagination_where_clause(
- direction=direction,
- column_names=("topological_ordering", "stream_ordering"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
- engine=self.database_engine,
- )
-
- if pagination_clause:
- where_clause.append(pagination_clause)
-
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
- sql = """
- SELECT event_id, topological_ordering, stream_ordering
- FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE %s
- ORDER BY topological_ordering %s, stream_ordering %s
- LIMIT ?
- """ % (
- " AND ".join(where_clause),
- order,
- order,
- )
-
- def _get_recent_references_for_event_txn(txn):
- txn.execute(sql, where_args + [limit + 1])
-
- last_topo_id = None
- last_stream_id = None
- events = []
- for row in txn:
- events.append({"event_id": row[0]})
- last_topo_id = row[1]
- last_stream_id = row[2]
-
- next_batch = None
- if len(events) > limit and last_topo_id and last_stream_id:
- next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
-
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
- )
-
- return self.runInteraction(
- "get_recent_references_for_event", _get_recent_references_for_event_txn
- )
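# Illustrative sketch (standalone model, not the real txn): the limit + 1 trick
# used above. One extra row is requested; if it comes back, the chunk is
# truncated to `limit` and the extra row supplies the next_batch token of
# (topological_ordering, stream_ordering).
def paginate(rows, limit):
    # rows: up to limit + 1 tuples of (event_id, topological_ordering, stream_ordering)
    chunk = [{"event_id": r[0]} for r in rows[:limit]]
    next_batch = None
    if len(rows) > limit:
        _, topo, stream = rows[-1]
        next_batch = (topo, stream)
    return chunk, next_batch

chunk, token = paginate([("$a", 3, 30), ("$b", 2, 20), ("$c", 1, 10)], limit=2)
assert [e["event_id"] for e in chunk] == ["$a", "$b"]
assert token == (1, 10)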
-
- @cached(tree=True)
- def get_aggregation_groups_for_event(
- self,
- event_id,
- event_type=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
- """Get a list of annotations on the event, grouped by event type and
- aggregation key, sorted by count.
-
- This is used e.g. to get which reactions have happened on an event and
- how many of each there are.
-
- Args:
- event_id (str): Fetch events that relate to this event ID.
- event_type (str|None): Only fetch events with this event type, if
- given.
- limit (int): Only fetch the `limit` groups.
- direction (str): Whether to fetch the highest count first (`"b"`) or
- the lowest count first (`"f"`).
- from_token (AggregationPaginationToken|None): Fetch rows from the
- given token, or from the start if None.
- to_token (AggregationPaginationToken|None): Fetch rows up to the
- given token, or up to the end if None.
-
-
- Returns:
- Deferred[PaginationChunk]: List of groups of annotations that
- match. Each row is a dict with `type`, `key` and `count` fields.
- """
-
- where_clause = ["relates_to_id = ?", "relation_type = ?"]
- where_args = [event_id, RelationTypes.ANNOTATION]
-
- if event_type:
- where_clause.append("type = ?")
- where_args.append(event_type)
-
- having_clause = generate_pagination_where_clause(
- direction=direction,
- column_names=("COUNT(*)", "MAX(stream_ordering)"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
- engine=self.database_engine,
- )
-
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
- if having_clause:
- having_clause = "HAVING " + having_clause
- else:
- having_clause = ""
-
- sql = """
- SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
- FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE {where_clause}
- GROUP BY relation_type, type, aggregation_key
- {having_clause}
- ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
- LIMIT ?
- """.format(
- where_clause=" AND ".join(where_clause),
- order=order,
- having_clause=having_clause,
- )
-
- def _get_aggregation_groups_for_event_txn(txn):
- txn.execute(sql, where_args + [limit + 1])
-
- next_batch = None
- events = []
- for row in txn:
- events.append({"type": row[0], "key": row[1], "count": row[2]})
- next_batch = AggregationPaginationToken(row[2], row[3])
-
- if len(events) <= limit:
- next_batch = None
-
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
- )
-
- return self.runInteraction(
- "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
- )
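# Illustrative sketch (standalone, hypothetical events): what the
# COUNT(DISTINCT sender) / GROUP BY above computes for annotations, modelled
# on plain (type, aggregation_key, sender) tuples.
from collections import defaultdict

annotations = [
    ("m.reaction", "👍", "@alice:example.org"),
    ("m.reaction", "👍", "@bob:example.org"),
    ("m.reaction", "👍", "@alice:example.org"),  # same sender, counted once
    ("m.reaction", "🎉", "@alice:example.org"),
]
senders_by_group = defaultdict(set)
for ev_type, key, sender in annotations:
    senders_by_group[(ev_type, key)].add(sender)
counts = {group: len(s) for group, s in senders_by_group.items()}
assert counts[("m.reaction", "👍")] == 2
assert counts[("m.reaction", "🎉")] == 1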
-
- @cachedInlineCallbacks()
- def get_applicable_edit(self, event_id):
- """Get the most recent edit (if any) that has happened for the given
- event.
-
- Correctly handles checking whether edits were allowed to happen.
-
- Args:
- event_id (str): The original event ID
-
- Returns:
- Deferred[EventBase|None]: Returns the most recent edit, if any.
- """
-
- # We only allow edits for `m.room.message` events that have the same sender
- # and event type. We can't assert these things during regular event auth so
- # we have to do the checks post hoc.
-
- # Fetches latest edit that has the same type and sender as the
- # original, and is an `m.room.message`.
- sql = """
- SELECT edit.event_id FROM events AS edit
- INNER JOIN event_relations USING (event_id)
- INNER JOIN events AS original ON
- original.event_id = relates_to_id
- AND edit.type = original.type
- AND edit.sender = original.sender
- WHERE
- relates_to_id = ?
- AND relation_type = ?
- AND edit.type = 'm.room.message'
- ORDER by edit.origin_server_ts DESC, edit.event_id DESC
- LIMIT 1
- """
-
- def _get_applicable_edit_txn(txn):
- txn.execute(sql, (event_id, RelationTypes.REPLACE))
- row = txn.fetchone()
- if row:
- return row[0]
-
- edit_id = yield self.runInteraction(
- "get_applicable_edit", _get_applicable_edit_txn
- )
-
- if not edit_id:
- return
-
- edit_event = yield self.get_event(edit_id, allow_none=True)
- defer.returnValue(edit_event)
-
- def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
- """Check if a user has already annotated an event with the same key
- (e.g. already liked an event).
-
- Args:
- parent_id (str): The event being annotated
- event_type (str): The event type of the annotation
- aggregation_key (str): The aggregation key of the annotation
- sender (str): The sender of the annotation
-
- Returns:
- Deferred[bool]
- """
-
- sql = """
- SELECT 1 FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE
- relates_to_id = ?
- AND relation_type = ?
- AND type = ?
- AND sender = ?
- AND aggregation_key = ?
- LIMIT 1;
- """
-
- def _get_if_user_has_annotated_event(txn):
- txn.execute(
- sql,
- (
- parent_id,
- RelationTypes.ANNOTATION,
- event_type,
- sender,
- aggregation_key,
- ),
- )
-
- return bool(txn.fetchone())
-
- return self.runInteraction(
- "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
- )
-
-
-class RelationsStore(RelationsWorkerStore):
- def _handle_event_relations(self, txn, event):
- """Handles inserting relation data during peristence of events
-
- Args:
- txn
- event (EventBase)
- """
- relation = event.content.get("m.relates_to")
- if not relation:
- # No relations
- return
-
- rel_type = relation.get("rel_type")
- if rel_type not in (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.REPLACE,
- ):
- # Unknown relation type
- return
-
- parent_id = relation.get("event_id")
- if not parent_id:
- # Invalid relation
- return
-
- aggregation_key = relation.get("key")
-
- self._simple_insert_txn(
- txn,
- table="event_relations",
- values={
- "event_id": event.event_id,
- "relates_to_id": parent_id,
- "relation_type": rel_type,
- "aggregation_key": aggregation_key,
- },
- )
-
- txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
- txn.call_after(
- self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
- )
-
- if rel_type == RelationTypes.REPLACE:
- txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
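# Illustrative sketch (standalone; the literal "m.annotation" value is an
# assumption about RelationTypes.ANNOTATION): the shape of event content the
# handler above accepts, and the event_relations row it would derive from it.
content = {
    "m.relates_to": {
        "rel_type": "m.annotation",
        "event_id": "$parent_event:example.org",
        "key": "👍",
    }
}
relation = content.get("m.relates_to") or {}
row = {
    "relates_to_id": relation.get("event_id"),
    "relation_type": relation.get("rel_type"),
    "aggregation_key": relation.get("key"),
}
assert row["relates_to_id"] == "$parent_event:example.org"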
-
- def _handle_redaction(self, txn, redacted_event_id):
- """Handles receiving a redaction and checking whether we need to remove
- any redacted relations from the database.
-
- Args:
- txn
- redacted_event_id (str): The event that was redacted.
- """
-
- self._simple_delete_txn(
- txn,
- table="event_relations",
- keyvalues={
- "event_id": redacted_event_id,
- }
- )
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
deleted file mode 100644
index db3d052d33..0000000000
--- a/synapse/storage/room.py
+++ /dev/null
@@ -1,900 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import collections
-import logging
-import re
-
-from six import integer_types
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes
-from synapse.api.errors import StoreError
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.search import SearchStore
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-
-logger = logging.getLogger(__name__)
-
-
-OpsLevel = collections.namedtuple(
- "OpsLevel", ("ban_level", "kick_level", "redact_level")
-)
-
-RatelimitOverride = collections.namedtuple(
- "RatelimitOverride", ("messages_per_second", "burst_count")
-)
-
-
-class RoomWorkerStore(SQLBaseStore):
- def get_room(self, room_id):
- """Retrieve a room.
-
- Args:
- room_id (str): The ID of the room to retrieve.
- Returns:
- A dict containing the room information, or None if the room is unknown.
- """
- return self._simple_select_one(
- table="rooms",
- keyvalues={"room_id": room_id},
- retcols=("room_id", "is_public", "creator"),
- desc="get_room",
- allow_none=True,
- )
-
- def get_public_room_ids(self):
- return self._simple_select_onecol(
- table="rooms",
- keyvalues={"is_public": True},
- retcol="room_id",
- desc="get_public_room_ids",
- )
-
- @cached(num_args=2, max_entries=100)
- def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
- """Get pulbic rooms for a particular list, or across all lists.
-
- Args:
- stream_id (int)
- network_tuple (ThirdPartyInstanceID): The list to use (None, None)
- means the main list, None means all lists.
- """
- return self.runInteraction(
- "get_public_room_ids_at_stream_id",
- self.get_public_room_ids_at_stream_id_txn,
- stream_id,
- network_tuple=network_tuple,
- )
-
- def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple):
- return {
- rm
- for rm, vis in self.get_published_at_stream_id_txn(
- txn, stream_id, network_tuple=network_tuple
- ).items()
- if vis
- }
-
- def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
- if network_tuple:
- # We want to get from a particular list. No aggregation required.
-
- sql = """
- SELECT room_id, visibility FROM public_room_list_stream
- INNER JOIN (
- SELECT room_id, max(stream_id) AS stream_id
- FROM public_room_list_stream
- WHERE stream_id <= ? %s
- GROUP BY room_id
- ) grouped USING (room_id, stream_id)
- """
-
- if network_tuple.appservice_id is not None:
- txn.execute(
- sql % ("AND appservice_id = ? AND network_id = ?",),
- (stream_id, network_tuple.appservice_id, network_tuple.network_id),
- )
- else:
- txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,))
- return dict(txn)
- else:
- # We want to get from all lists, so we need to aggregate the results
-
- logger.info("Executing full list")
-
- sql = """
- SELECT room_id, visibility
- FROM public_room_list_stream
- INNER JOIN (
- SELECT
- room_id, max(stream_id) AS stream_id, appservice_id,
- network_id
- FROM public_room_list_stream
- WHERE stream_id <= ?
- GROUP BY room_id, appservice_id, network_id
- ) grouped USING (room_id, stream_id)
- """
-
- txn.execute(sql, (stream_id,))
-
- results = {}
- # A room is visible if it's visible on any list.
- for room_id, visibility in txn:
- results[room_id] = bool(visibility) or results.get(room_id, False)
-
- return results
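# Illustrative sketch (standalone, hypothetical rows): the "visible on any
# list" aggregation used above over (room_id, visibility) rows pulled from all
# of the appservice/network lists.
rows = [("!a:example.org", 0), ("!a:example.org", 1), ("!b:example.org", 0)]
results = {}
for room_id, visibility in rows:
    results[room_id] = bool(visibility) or results.get(room_id, False)
assert results == {"!a:example.org": True, "!b:example.org": False}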
-
- def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple):
- def get_public_room_changes_txn(txn):
- then_rooms = self.get_public_room_ids_at_stream_id_txn(
- txn, prev_stream_id, network_tuple
- )
-
- now_rooms_dict = self.get_published_at_stream_id_txn(
- txn, new_stream_id, network_tuple
- )
-
- now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
- now_rooms_not_visible = set(
- rm for rm, vis in now_rooms_dict.items() if not vis
- )
-
- newly_visible = now_rooms_visible - then_rooms
- newly_unpublished = now_rooms_not_visible & then_rooms
-
- return newly_visible, newly_unpublished
-
- return self.runInteraction(
- "get_public_room_changes", get_public_room_changes_txn
- )
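# Illustrative sketch (standalone, hypothetical room IDs): the set arithmetic
# used above to diff two snapshots of the public room list.
then_rooms = {"!a:example.org", "!b:example.org"}
now_visible = {"!b:example.org", "!c:example.org"}
now_not_visible = {"!a:example.org"}
newly_visible = now_visible - then_rooms
newly_unpublished = now_not_visible & then_rooms
assert newly_visible == {"!c:example.org"}
assert newly_unpublished == {"!a:example.org"}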
-
- @cached(max_entries=10000)
- def is_room_blocked(self, room_id):
- return self._simple_select_one_onecol(
- table="blocked_rooms",
- keyvalues={"room_id": room_id},
- retcol="1",
- allow_none=True,
- desc="is_room_blocked",
- )
-
- @defer.inlineCallbacks
- def is_room_published(self, room_id):
- """Check whether a room has been published in the local public room
- directory.
-
- Args:
- room_id (str)
- Returns:
- bool: Whether the room is currently published in the room directory
- """
- # Get room information
- room_info = yield self.get_room(room_id)
- if not room_info:
- defer.returnValue(False)
-
- # Check the is_public value
- defer.returnValue(room_info.get("is_public", False))
-
- @cachedInlineCallbacks(max_entries=10000)
- def get_ratelimit_for_user(self, user_id):
- """Check if there are any overrides for ratelimiting for the given
- user
-
- Args:
- user_id (str)
-
- Returns:
- RatelimitOverride if there is an override, else None. If the contents
- of RatelimitOverride are None or 0 then ratelimiting has been
- disabled for that user entirely.
- """
- row = yield self._simple_select_one(
- table="ratelimit_override",
- keyvalues={"user_id": user_id},
- retcols=("messages_per_second", "burst_count"),
- allow_none=True,
- desc="get_ratelimit_for_user",
- )
-
- if row:
- defer.returnValue(
- RatelimitOverride(
- messages_per_second=row["messages_per_second"],
- burst_count=row["burst_count"],
- )
- )
- else:
- defer.returnValue(None)
-
- @cachedInlineCallbacks()
- def get_retention_policy_for_room(self, room_id):
- """Get the retention policy for a given room.
-
- If no retention policy has been found for this room, returns a policy defined
- by the configured default policy (which has None as both the 'min_lifetime' and
- the 'max_lifetime' if no default policy has been defined in the server's
- configuration).
-
- Args:
- room_id (str): The ID of the room to get the retention policy of.
-
- Returns:
- dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
- """
- # If the room retention feature is disabled, return a policy with no minimum nor
- # maximum, so that we don't filter out any events when sending them to
- # the client.
- if not self.config.retention_enabled:
- defer.returnValue({
- "min_lifetime": None,
- "max_lifetime": None,
- })
-
- def get_retention_policy_for_room_txn(txn):
- txn.execute(
- """
- SELECT min_lifetime, max_lifetime FROM room_retention
- INNER JOIN current_state_events USING (event_id, room_id)
- WHERE room_id = ?;
- """,
- (room_id,)
- )
-
- return self.cursor_to_dict(txn)
-
- ret = yield self.runInteraction(
- "get_retention_policy_for_room",
- get_retention_policy_for_room_txn,
- )
-
- # If we don't know this room ID, ret will be empty; in this case return the
- # default policy.
- if not ret:
- defer.returnValue({
- "min_lifetime": self.config.retention_default_min_lifetime,
- "max_lifetime": self.config.retention_default_max_lifetime,
- })
-
- row = ret[0]
-
- # If one of the room's policy's attributes isn't defined, use the matching
- # attribute from the default policy.
- # The default values will be None if no default policy has been defined, or if one
- # of the attributes is missing from the default policy.
- if row["min_lifetime"] is None:
- row["min_lifetime"] = self.config.retention_default_min_lifetime
-
- if row["max_lifetime"] is None:
- row["max_lifetime"] = self.config.retention_default_max_lifetime
-
- defer.returnValue(row)
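# Illustrative sketch (standalone; the default values are made up): the
# per-field fallback applied above when the room's retention policy omits one
# of the lifetimes.
default_min_lifetime = None
default_max_lifetime = 7 * 24 * 3600 * 1000  # e.g. one week, in milliseconds
row = {"min_lifetime": 86400000, "max_lifetime": None}
if row["min_lifetime"] is None:
    row["min_lifetime"] = default_min_lifetime
if row["max_lifetime"] is None:
    row["max_lifetime"] = default_max_lifetime
assert row == {"min_lifetime": 86400000, "max_lifetime": 604800000}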
-
-
-class RoomStore(RoomWorkerStore, SearchStore):
- def __init__(self, db_conn, hs):
- super(RoomStore, self).__init__(db_conn, hs)
-
- self.config = hs.config
-
- self.register_background_update_handler(
- "insert_room_retention", self._background_insert_retention,
- )
-
- @defer.inlineCallbacks
- def _background_insert_retention(self, progress, batch_size):
- """Retrieves a list of all rooms within a range and inserts an entry for each of
- them into the room_retention table.
- NULLs the property's columns if missing from the retention event in the room's
- state (or NULLs all of them if there's no retention event in the room's state),
- so that we fall back to the server's retention policy.
- """
-
- last_room = progress.get("room_id", "")
-
- def _background_insert_retention_txn(txn):
- txn.execute(
- """
- SELECT state.room_id, state.event_id, events.json
- FROM current_state_events as state
- LEFT JOIN event_json AS events ON (state.event_id = events.event_id)
- WHERE state.room_id > ? AND state.type = '%s'
- ORDER BY state.room_id ASC
- LIMIT ?;
- """ % EventTypes.Retention,
- (last_room, batch_size)
- )
-
- rows = self.cursor_to_dict(txn)
-
- if not rows:
- return True
-
- for row in rows:
- if not row["json"]:
- retention_policy = {}
- else:
- ev = json.loads(row["json"])
- retention_policy = ev["content"]
-
- self._simple_insert_txn(
- txn=txn,
- table="room_retention",
- values={
- "room_id": row["room_id"],
- "event_id": row["event_id"],
- "min_lifetime": retention_policy.get("min_lifetime"),
- "max_lifetime": retention_policy.get("max_lifetime"),
- }
- )
-
- logger.info("Inserted %d rows into room_retention", len(rows))
-
- self._background_update_progress_txn(
- txn, "insert_room_retention", {
- "room_id": rows[-1]["room_id"],
- }
- )
-
- if batch_size > len(rows):
- return True
- else:
- return False
-
- end = yield self.runInteraction(
- "insert_room_retention",
- _background_insert_retention_txn,
- )
-
- if end:
- yield self._end_background_update("insert_room_retention")
-
- defer.returnValue(batch_size)
-
- @defer.inlineCallbacks
- def store_room(self, room_id, room_creator_user_id, is_public):
- """Stores a room.
-
- Args:
- room_id (str): The desired room ID, can be None.
- room_creator_user_id (str): The user ID of the room creator.
- is_public (bool): True to indicate that this room should appear in
- public room lists.
- Raises:
- StoreError if the room could not be stored.
- """
- try:
-
- def store_room_txn(txn, next_id):
- self._simple_insert_txn(
- txn,
- "rooms",
- {
- "room_id": room_id,
- "creator": room_creator_user_id,
- "is_public": is_public,
- },
- )
- if is_public:
- self._simple_insert_txn(
- txn,
- table="public_room_list_stream",
- values={
- "stream_id": next_id,
- "room_id": room_id,
- "visibility": is_public,
- },
- )
-
- with self._public_room_id_gen.get_next() as next_id:
- yield self.runInteraction("store_room_txn", store_room_txn, next_id)
- except Exception as e:
- logger.error("store_room with room_id=%s failed: %s", room_id, e)
- raise StoreError(500, "Problem creating room.")
-
- @defer.inlineCallbacks
- def set_room_is_public(self, room_id, is_public):
- def set_room_is_public_txn(txn, next_id):
- self._simple_update_one_txn(
- txn,
- table="rooms",
- keyvalues={"room_id": room_id},
- updatevalues={"is_public": is_public},
- )
-
- entries = self._simple_select_list_txn(
- txn,
- table="public_room_list_stream",
- keyvalues={
- "room_id": room_id,
- "appservice_id": None,
- "network_id": None,
- },
- retcols=("stream_id", "visibility"),
- )
-
- entries.sort(key=lambda r: r["stream_id"])
-
- add_to_stream = True
- if entries:
- add_to_stream = bool(entries[-1]["visibility"]) != is_public
-
- if add_to_stream:
- self._simple_insert_txn(
- txn,
- table="public_room_list_stream",
- values={
- "stream_id": next_id,
- "room_id": room_id,
- "visibility": is_public,
- "appservice_id": None,
- "network_id": None,
- },
- )
-
- with self._public_room_id_gen.get_next() as next_id:
- yield self.runInteraction(
- "set_room_is_public", set_room_is_public_txn, next_id
- )
- self.hs.get_notifier().on_new_replication_data()
-
- @defer.inlineCallbacks
- def set_room_is_public_appservice(
- self, room_id, appservice_id, network_id, is_public
- ):
- """Edit the appservice/network specific public room list.
-
- Each appservice can have a number of published room lists associated
- with it, keyed off an appservice-defined `network_id`, which
- represents a single instance of a bridge to a third-party
- network.
-
- Args:
- room_id (str)
- appservice_id (str)
- network_id (str)
- is_public (bool): Whether to publish or unpublish the room from the
- list.
- """
-
- def set_room_is_public_appservice_txn(txn, next_id):
- if is_public:
- try:
- self._simple_insert_txn(
- txn,
- table="appservice_room_list",
- values={
- "appservice_id": appservice_id,
- "network_id": network_id,
- "room_id": room_id,
- },
- )
- except self.database_engine.module.IntegrityError:
- # We've already inserted, nothing to do.
- return
- else:
- self._simple_delete_txn(
- txn,
- table="appservice_room_list",
- keyvalues={
- "appservice_id": appservice_id,
- "network_id": network_id,
- "room_id": room_id,
- },
- )
-
- entries = self._simple_select_list_txn(
- txn,
- table="public_room_list_stream",
- keyvalues={
- "room_id": room_id,
- "appservice_id": appservice_id,
- "network_id": network_id,
- },
- retcols=("stream_id", "visibility"),
- )
-
- entries.sort(key=lambda r: r["stream_id"])
-
- add_to_stream = True
- if entries:
- add_to_stream = bool(entries[-1]["visibility"]) != is_public
-
- if add_to_stream:
- self._simple_insert_txn(
- txn,
- table="public_room_list_stream",
- values={
- "stream_id": next_id,
- "room_id": room_id,
- "visibility": is_public,
- "appservice_id": appservice_id,
- "network_id": network_id,
- },
- )
-
- with self._public_room_id_gen.get_next() as next_id:
- yield self.runInteraction(
- "set_room_is_public_appservice",
- set_room_is_public_appservice_txn,
- next_id,
- )
- self.hs.get_notifier().on_new_replication_data()
-
- def get_room_count(self):
- """Retrieve a list of all rooms
- """
-
- def f(txn):
- sql = "SELECT count(*) FROM rooms"
- txn.execute(sql)
- row = txn.fetchone()
- return row[0] or 0
-
- return self.runInteraction("get_rooms", f)
-
- def _store_room_topic_txn(self, txn, event):
- if hasattr(event, "content") and "topic" in event.content:
- self._simple_insert_txn(
- txn,
- "topics",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "topic": event.content["topic"],
- },
- )
-
- self.store_event_search_txn(
- txn, event, "content.topic", event.content["topic"]
- )
-
- def _store_room_name_txn(self, txn, event):
- if hasattr(event, "content") and "name" in event.content:
- self._simple_insert_txn(
- txn,
- "room_names",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "name": event.content["name"],
- },
- )
-
- self.store_event_search_txn(
- txn, event, "content.name", event.content["name"]
- )
-
- def _store_room_message_txn(self, txn, event):
- if hasattr(event, "content") and "body" in event.content:
- self.store_event_search_txn(
- txn, event, "content.body", event.content["body"]
- )
-
- def _store_history_visibility_txn(self, txn, event):
- self._store_content_index_txn(txn, event, "history_visibility")
-
- def _store_guest_access_txn(self, txn, event):
- self._store_content_index_txn(txn, event, "guest_access")
-
- def _store_content_index_txn(self, txn, event, key):
- if hasattr(event, "content") and key in event.content:
- sql = (
- "INSERT INTO %(key)s"
- " (event_id, room_id, %(key)s)"
- " VALUES (?, ?, ?)" % {"key": key}
- )
- txn.execute(sql, (event.event_id, event.room_id, event.content[key]))
-
- def _store_retention_policy_for_room_txn(self, txn, event):
- if (
- hasattr(event, "content")
- and ("min_lifetime" in event.content or "max_lifetime" in event.content)
- ):
- if (
- ("min_lifetime" in event.content and not isinstance(
- event.content.get("min_lifetime"), integer_types
- ))
- or ("max_lifetime" in event.content and not isinstance(
- event.content.get("max_lifetime"), integer_types
- ))
- ):
- # Ignore the event if one of the values isn't an integer.
- return
-
- self._simple_insert_txn(
- txn=txn,
- table="room_retention",
- values={
- "room_id": event.room_id,
- "event_id": event.event_id,
- "min_lifetime": event.content.get("min_lifetime"),
- "max_lifetime": event.content.get("max_lifetime"),
- },
- )
-
- self._invalidate_cache_and_stream(
- txn, self.get_retention_policy_for_room, (event.room_id,)
- )
-
- def add_event_report(
- self, room_id, event_id, user_id, reason, content, received_ts
- ):
- next_id = self._event_reports_id_gen.get_next()
- return self._simple_insert(
- table="event_reports",
- values={
- "id": next_id,
- "received_ts": received_ts,
- "room_id": room_id,
- "event_id": event_id,
- "user_id": user_id,
- "reason": reason,
- "content": json.dumps(content),
- },
- desc="add_event_report",
- )
-
- def get_current_public_room_stream_id(self):
- return self._public_room_id_gen.get_current_token()
-
- def get_all_new_public_rooms(self, prev_id, current_id, limit):
- def get_all_new_public_rooms(txn):
- sql = """
- SELECT stream_id, room_id, visibility, appservice_id, network_id
- FROM public_room_list_stream
- WHERE stream_id > ? AND stream_id <= ?
- ORDER BY stream_id ASC
- LIMIT ?
- """
-
- txn.execute(sql, (prev_id, current_id, limit))
- return txn.fetchall()
-
- if prev_id == current_id:
- return defer.succeed([])
-
- return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
-
- @defer.inlineCallbacks
- def block_room(self, room_id, user_id):
- """Marks the room as blocked. Can be called multiple times.
-
- Args:
- room_id (str): Room to block
- user_id (str): Who blocked it
-
- Returns:
- Deferred
- """
- yield self._simple_upsert(
- table="blocked_rooms",
- keyvalues={"room_id": room_id},
- values={},
- insertion_values={"user_id": user_id},
- desc="block_room",
- )
- yield self.runInteraction(
- "block_room_invalidation",
- self._invalidate_cache_and_stream,
- self.is_room_blocked,
- (room_id,),
- )
-
- def get_media_mxcs_in_room(self, room_id):
- """Retrieves all the local and remote media MXC URIs in a given room
-
- Args:
- room_id (str)
-
- Returns:
- A pair of lists of mxc:// URIs: one for local media and one for
- remote media.
- """
-
- def _get_media_mxcs_in_room_txn(txn):
- local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
- local_media_mxcs = []
- remote_media_mxcs = []
-
- # Convert the IDs to MXC URIs
- for media_id in local_mxcs:
- local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
- for hostname, media_id in remote_mxcs:
- remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
-
- return local_media_mxcs, remote_media_mxcs
-
- return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
-
- def quarantine_media_ids_in_room(self, room_id, quarantined_by):
- """For a room loops through all events with media and quarantines
- the associated media
- """
-
- def _quarantine_media_in_room_txn(txn):
- local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
- total_media_quarantined = 0
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany(
- """
- UPDATE local_media_repository
- SET quarantined_by = ?
- WHERE media_id = ?
- """,
- ((quarantined_by, media_id) for media_id in local_mxcs),
- )
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin = ? AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_mxcs
- ),
- )
-
- total_media_quarantined += len(local_mxcs)
- total_media_quarantined += len(remote_mxcs)
-
- return total_media_quarantined
-
- return self.runInteraction(
- "quarantine_media_in_room", _quarantine_media_in_room_txn
- )
-
- def _get_media_mxcs_in_room_txn(self, txn, room_id):
- """Retrieves all the local and remote media MXC URIs in a given room
-
- Args:
- txn (cursor)
- room_id (str)
-
- Returns:
- A pair of lists: the local media IDs, and the remote media as
- (hostname, media ID) tuples.
- """
- mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
-
- next_token = self.get_current_events_token() + 1
- local_media_mxcs = []
- remote_media_mxcs = []
-
- while next_token:
- sql = """
- SELECT stream_ordering, json FROM events
- JOIN event_json USING (room_id, event_id)
- WHERE room_id = ?
- AND stream_ordering < ?
- AND contains_url = ? AND outlier = ?
- ORDER BY stream_ordering DESC
- LIMIT ?
- """
- txn.execute(sql, (room_id, next_token, True, False, 100))
-
- next_token = None
- for stream_ordering, content_json in txn:
- next_token = stream_ordering
- event_json = json.loads(content_json)
- content = event_json["content"]
- content_url = content.get("url")
- thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
- for url in (content_url, thumbnail_url):
- if not url:
- continue
- matches = mxc_re.match(url)
- if matches:
- hostname = matches.group(1)
- media_id = matches.group(2)
- if hostname == self.hs.hostname:
- local_media_mxcs.append(media_id)
- else:
- remote_media_mxcs.append((hostname, media_id))
-
- return local_media_mxcs, remote_media_mxcs
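# Illustrative sketch (standalone; "example.org" is an assumed local server
# name): how the mxc:// regex above splits a URI into hostname and media ID,
# which decides whether the media is local or remote.
import re

mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
local_hostname = "example.org"

match = mxc_re.match("mxc://matrix.org/abcDEF123")
assert match.group(1) == "matrix.org"
assert match.group(2) == "abcDEF123"
assert (match.group(1) == local_hostname) is False  # remote media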
-
- @defer.inlineCallbacks
- def get_rooms_for_retention_period_in_range(self, min_ms, max_ms, include_null=False):
- """Retrieves all of the rooms within the given retention range.
-
- Optionally includes the rooms which don't have a retention policy.
-
- Args:
- min_ms (int|None): Duration in milliseconds that defines the lower limit of
- the range to handle (exclusive). If None, doesn't set a lower limit.
- max_ms (int|None): Duration in milliseconds that defines the upper limit of
- the range to handle (inclusive). If None, doesn't set an upper limit.
- include_null (bool): Whether to include rooms whose retention policy is NULL
- in the returned set.
-
- Returns:
- dict[str, dict]: The rooms within this range, along with their retention
- policy. The key is "room_id", and maps to a dict describing the retention
- policy associated with this room ID. The keys for this nested dict are
- "min_lifetime" (int|None), and "max_lifetime" (int|None).
- """
-
- def get_rooms_for_retention_period_in_range_txn(txn):
- range_conditions = []
- args = []
-
- if min_ms is not None:
- range_conditions.append("max_lifetime > ?")
- args.append(min_ms)
-
- if max_ms is not None:
- range_conditions.append("max_lifetime <= ?")
- args.append(max_ms)
-
- # Do a first query which will retrieve the rooms that have a retention policy
- # in their current state.
- sql = """
- SELECT room_id, min_lifetime, max_lifetime FROM room_retention
- INNER JOIN current_state_events USING (event_id, room_id)
- """
-
- if len(range_conditions):
- sql += " WHERE (" + " AND ".join(range_conditions) + ")"
-
- if include_null:
- sql += " OR max_lifetime IS NULL"
-
- txn.execute(sql, args)
-
- rows = self.cursor_to_dict(txn)
- rooms_dict = {}
-
- for row in rows:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": row["min_lifetime"],
- "max_lifetime": row["max_lifetime"],
- }
-
- if include_null:
- # If required, do a second query that retrieves all of the rooms we know
- # of so we can handle rooms with no retention policy.
- sql = "SELECT DISTINCT room_id FROM current_state_events"
-
- txn.execute(sql)
-
- rows = self.cursor_to_dict(txn)
-
- # If a room isn't already in the dict (i.e. it doesn't have a retention
- # policy in its state), add it with a null policy.
- for row in rows:
- if row["room_id"] not in rooms_dict:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": None,
- "max_lifetime": None,
- }
-
- return rooms_dict
-
- rooms = yield self.runInteraction(
- "get_rooms_for_retention_period_in_range",
- get_rooms_for_retention_period_in_range_txn,
- )
-
- defer.returnValue(rooms)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 7617913326..8c4a83a840 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,20 +17,6 @@
import logging
from collections import namedtuple
-from six import iteritems, itervalues
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes, Membership
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.types import get_domain_from_id
-from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.util.stringutils import to_ascii
-
logger = logging.getLogger(__name__)
@@ -51,780 +37,3 @@ ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
# a given membership type, suitable for use in calculating heroes for a room.
# "count" points to the total numberr of users of a given membership type.
MemberSummary = namedtuple("MemberSummary", ("members", "count"))
-
-_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
-
-
-class RoomMemberWorkerStore(EventsWorkerStore):
- @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
- def get_hosts_in_room(self, room_id, cache_context):
- """Returns the set of all hosts currently in the room
- """
- user_ids = yield self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
- hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
- defer.returnValue(hosts)
-
- @cached(max_entries=100000, iterable=True)
- def get_users_in_room(self, room_id):
- def f(txn):
- sql = (
- "SELECT m.user_id FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id "
- " AND m.room_id = c.room_id "
- " AND m.user_id = c.state_key"
- " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
- )
-
- txn.execute(sql, (room_id, Membership.JOIN))
- return [to_ascii(r[0]) for r in txn]
-
- return self.runInteraction("get_users_in_room", f)
-
- @cached(max_entries=100000)
- def get_room_summary(self, room_id):
- """ Get the details of a room roughly suitable for use by the room
- summary extension to /sync. Useful when lazy loading room members.
- Args:
- room_id (str): The room ID to query
- Returns:
- Deferred[dict[str, MemberSummary]]:
- dict of membership states, pointing to a MemberSummary named tuple.
- """
-
- def _get_room_summary_txn(txn):
- # first get counts.
- # We do this all in one transaction to keep the cache small.
- # FIXME: get rid of this when we have room_stats
- sql = """
- SELECT count(*), m.membership FROM room_memberships as m
- INNER JOIN current_state_events as c
- ON m.event_id = c.event_id
- AND m.room_id = c.room_id
- AND m.user_id = c.state_key
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- GROUP BY m.membership
- """
-
- txn.execute(sql, (room_id,))
- res = {}
- for count, membership in txn:
- summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
-
- # we order by membership and then fairly arbitrarily by event_id so
- # heroes are consistent
- sql = """
- SELECT m.user_id, m.membership, m.event_id
- FROM room_memberships as m
- INNER JOIN current_state_events as c
- ON m.event_id = c.event_id
- AND m.room_id = c.room_id
- AND m.user_id = c.state_key
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- ORDER BY
- CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
- m.event_id ASC
- LIMIT ?
- """
-
- # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
- txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
- for user_id, membership, event_id in txn:
- summary = res[to_ascii(membership)]
- # we will always have a summary for this membership type at this
- # point given the summary currently contains the counts.
- members = summary.members
- members.append((to_ascii(user_id), to_ascii(event_id)))
-
- return res
-
- return self.runInteraction("get_room_summary", _get_room_summary_txn)
-
- def _get_user_counts_in_room_txn(self, txn, room_id):
- """
- Get the number of users in a room, grouped by membership.
-
- Args:
- txn
- room_id (str)
-
- Returns:
- dict[str, int]: Mapping from membership type to user count.
- """
- sql = """
- SELECT m.membership, count(*) FROM room_memberships as m
- INNER JOIN current_state_events as c USING(event_id)
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- GROUP BY m.membership
- """
-
- txn.execute(sql, (room_id,))
- return {row[0]: row[1] for row in txn}
-
- @cached()
- def get_invited_rooms_for_user(self, user_id):
- """ Get all the rooms the user is invited to
- Args:
- user_id (str): The user ID.
- Returns:
- A deferred list of RoomsForUser.
- """
-
- return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
-
- @defer.inlineCallbacks
- def get_invite_for_user_in_room(self, user_id, room_id):
- """Gets the invite for the given user and room
-
- Args:
- user_id (str)
- room_id (str)
-
- Returns:
- Deferred: Resolves to either a RoomsForUser or None if no invite was
- found.
- """
- invites = yield self.get_invited_rooms_for_user(user_id)
- for invite in invites:
- if invite.room_id == room_id:
- defer.returnValue(invite)
- defer.returnValue(None)
-
- def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
- """ Get all the rooms for this user where the membership for this user
- matches one in the membership list.
-
- Args:
- user_id (str): The user ID.
- membership_list (list): A list of synapse.api.constants.Membership
- values which the user must be in.
- Returns:
- A deferred list of RoomsForUser entries, each with room_id,
- membership and sender defined.
- """
- if not membership_list:
- return defer.succeed(None)
-
- return self.runInteraction(
- "get_rooms_for_user_where_membership_is",
- self._get_rooms_for_user_where_membership_is_txn,
- user_id,
- membership_list,
- )
-
- def _get_rooms_for_user_where_membership_is_txn(
- self, txn, user_id, membership_list
- ):
-
- do_invite = Membership.INVITE in membership_list
- membership_list = [m for m in membership_list if m != Membership.INVITE]
-
- results = []
- if membership_list:
- where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
- " OR ".join(["membership = ?" for _ in membership_list]),
- )
-
- args = [user_id]
- args.extend(membership_list)
-
- sql = (
- "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
- " FROM current_state_events as c"
- " INNER JOIN room_memberships as m"
- " ON m.event_id = c.event_id"
- " INNER JOIN events as e"
- " ON e.event_id = c.event_id"
- " AND m.room_id = c.room_id"
- " AND m.user_id = c.state_key"
- " WHERE c.type = 'm.room.member' AND %s"
- ) % (where_clause,)
-
- txn.execute(sql, args)
- results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
-
- if do_invite:
- sql = (
- "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
- " FROM local_invites as i"
- " INNER JOIN events as e USING (event_id)"
- " WHERE invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(sql, (user_id,))
- results.extend(
- RoomsForUser(
- room_id=r["room_id"],
- sender=r["inviter"],
- event_id=r["event_id"],
- stream_ordering=r["stream_ordering"],
- membership=Membership.INVITE,
- )
- for r in self.cursor_to_dict(txn)
- )
-
- return results
-
- @cachedInlineCallbacks(max_entries=500000, iterable=True)
- def get_rooms_for_user_with_stream_ordering(self, user_id):
- """Returns a set of room_ids the user is currently joined to
-
- Args:
- user_id (str)
-
- Returns:
- Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
- the rooms the user is in currently, along with the stream ordering
- of the most recent join for that user and room.
- """
- rooms = yield self.get_rooms_for_user_where_membership_is(
- user_id, membership_list=[Membership.JOIN]
- )
- defer.returnValue(
- frozenset(
- GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
- for r in rooms
- )
- )
-
- @defer.inlineCallbacks
- def get_rooms_for_user(self, user_id, on_invalidate=None):
- """Returns a set of room_ids the user is currently joined to
- """
- rooms = yield self.get_rooms_for_user_with_stream_ordering(
- user_id, on_invalidate=on_invalidate
- )
- defer.returnValue(frozenset(r.room_id for r in rooms))
-
- @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
- def get_users_who_share_room_with_user(self, user_id, cache_context):
- """Returns the set of users who share a room with `user_id`
- """
- room_ids = yield self.get_rooms_for_user(
- user_id, on_invalidate=cache_context.invalidate
- )
-
- user_who_share_room = set()
- for room_id in room_ids:
- user_ids = yield self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
- user_who_share_room.update(user_ids)
-
- defer.returnValue(user_who_share_room)
-
- @defer.inlineCallbacks
- def get_joined_users_from_context(self, event, context):
- state_group = context.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- current_state_ids = yield context.get_current_state_ids(self)
- result = yield self._get_joined_users_from_context(
- event.room_id, state_group, current_state_ids, event=event, context=context
- )
- defer.returnValue(result)
-
- def get_joined_users_from_state(self, room_id, state_entry):
- state_group = state_entry.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- return self._get_joined_users_from_context(
- room_id, state_group, state_entry.state, context=state_entry
- )
-
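
# A small sketch of the `state_group = object()` trick used above, with a
# plain dict standing in for the descriptor cache: because object() is never
# equal to any other object(), calls made before a state group is assigned
# can neither hit nor poison entries cached for other calls.
_cache = {}

def cached_joined_users(room_id, state_group, compute):
    if state_group is None:
        state_group = object()  # unique, uncacheable sentinel
    key = (room_id, state_group)
    if key not in _cache:
        _cache[key] = compute()
    return _cache[key]
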
- @cachedInlineCallbacks(
- num_args=2, cache_context=True, iterable=True, max_entries=100000
- )
- def _get_joined_users_from_context(
- self,
- room_id,
- state_group,
- current_state_ids,
- cache_context,
- event=None,
- context=None,
- ):
- # We don't use `state_group`, it's there so that we can cache based
- # on it. However, it's important that it's never None, since two current_states
- # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
-
- users_in_room = {}
- member_event_ids = [
- e_id
- for key, e_id in iteritems(current_state_ids)
- if key[0] == EventTypes.Member
- ]
-
- if context is not None:
- # If we have a context with a delta from a previous state group,
- # check if we also have the result from the previous group in cache.
- # If we do then we can reuse that result and simply update it with
- # any membership changes in `delta_ids`
- if context.prev_group and context.delta_ids:
- prev_res = self._get_joined_users_from_context.cache.get(
- (room_id, context.prev_group), None
- )
- if prev_res and isinstance(prev_res, dict):
- users_in_room = dict(prev_res)
- member_event_ids = [
- e_id
- for key, e_id in iteritems(context.delta_ids)
- if key[0] == EventTypes.Member
- ]
- for etype, state_key in context.delta_ids:
- users_in_room.pop(state_key, None)
-
- # We check if we have any of the member event ids in the event cache
- # before we ask the DB
-
- # We don't update the event cache hit ratio as it completely throws off
- # the hit ratio counts. After all, we don't populate the cache if we
- # miss it here
- event_map = self._get_events_from_cache(
- member_event_ids, allow_rejected=False, update_metrics=False
- )
-
- missing_member_event_ids = []
- for event_id in member_event_ids:
- ev_entry = event_map.get(event_id)
- if ev_entry:
- if ev_entry.event.membership == Membership.JOIN:
- users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
- display_name=to_ascii(
- ev_entry.event.content.get("displayname", None)
- ),
- avatar_url=to_ascii(
- ev_entry.event.content.get("avatar_url", None)
- ),
- )
- else:
- missing_member_event_ids.append(event_id)
-
- if missing_member_event_ids:
- rows = yield self._simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=missing_member_event_ids,
- retcols=('user_id', 'display_name', 'avatar_url'),
- keyvalues={"membership": Membership.JOIN},
- batch_size=500,
- desc="_get_joined_users_from_context",
- )
-
- users_in_room.update(
- {
- to_ascii(row["user_id"]): ProfileInfo(
- avatar_url=to_ascii(row["avatar_url"]),
- display_name=to_ascii(row["display_name"]),
- )
- for row in rows
- }
- )
-
- if event is not None and event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- if event.event_id in member_event_ids:
- users_in_room[to_ascii(event.state_key)] = ProfileInfo(
- display_name=to_ascii(event.content.get("displayname", None)),
- avatar_url=to_ascii(event.content.get("avatar_url", None)),
- )
-
- defer.returnValue(users_in_room)
-
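
# A sketch of the incremental update above: start from the cached result for
# the previous state group, drop every member touched by the delta, then
# re-add only those whose new membership is "join". The inputs are simplified
# stand-ins for Synapse's event/profile structures.
def apply_membership_delta(prev_users, delta_memberships):
    """prev_users: dict user_id -> profile; delta_memberships: dict user_id -> membership."""
    users = dict(prev_users)
    for user_id in delta_memberships:
        users.pop(user_id, None)
    for user_id, membership in delta_memberships.items():
        if membership == "join":
            users[user_id] = {"displayname": None, "avatar_url": None}
    return users

# apply_membership_delta({"@a:hs": {}}, {"@a:hs": "leave", "@b:hs": "join"})
# -> {"@b:hs": {"displayname": None, "avatar_url": None}}
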
- @cachedInlineCallbacks(max_entries=10000)
- def is_host_joined(self, room_id, host):
- if '%' in host or '_' in host:
- raise Exception("Invalid host name")
-
- sql = """
- SELECT state_key FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
- WHERE membership = 'join'
- AND type = 'm.room.member'
- AND c.room_id = ?
- AND state_key LIKE ?
- LIMIT 1
- """
-
- # We do need to be careful to ensure that host doesn't have any wild cards
- # in it, but we checked above for known ones and we'll check below that
- # the returned user actually has the correct domain.
- like_clause = "%:" + host
-
- rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
-
- if not rows:
- defer.returnValue(False)
-
- user_id = rows[0][0]
- if get_domain_from_id(user_id) != host:
- # This can only happen if the host name has something funky in it
- raise Exception("Invalid host name")
-
- defer.returnValue(True)
-
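
# A sketch of the host check above: the "%:<host>" LIKE pattern cheaply finds
# a candidate joined user, but the code still re-verifies the domain of the
# returned user ID. get_domain_from_id is approximated here by splitting on
# the first ":"; the user list stands in for the SQL lookup.
def domain_of(user_id):
    return user_id.split(":", 1)[1]

def is_host_joined(joined_user_ids, host):
    if "%" in host or "_" in host:
        raise ValueError("Invalid host name")
    like_suffix = ":" + host
    return any(
        user_id.endswith(like_suffix) and domain_of(user_id) == host
        for user_id in joined_user_ids
    )

# is_host_joined(["@alice:example.org"], "example.org")      -> True
# is_host_joined(["@a:bad.org:example.org"], "example.org")  -> False
#   (the LIKE-style suffix matches, but the domain double-check rejects it)
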
- @cachedInlineCallbacks()
- def was_host_joined(self, room_id, host):
- """Check whether the server is or ever was in the room.
-
- Args:
- room_id (str)
- host (str)
-
- Returns:
- Deferred: Resolves to True if the host is/was in the room, otherwise
- False.
- """
- if '%' in host or '_' in host:
- raise Exception("Invalid host name")
-
- sql = """
- SELECT user_id FROM room_memberships
- WHERE room_id = ?
- AND user_id LIKE ?
- AND membership = 'join'
- LIMIT 1
- """
-
- # We do need to be careful to ensure that host doesn't have any wild cards
- # in it, but we checked above for known ones and we'll check below that
- # the returned user actually has the correct domain.
- like_clause = "%:" + host
-
- rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
-
- if not rows:
- defer.returnValue(False)
-
- user_id = rows[0][0]
- if get_domain_from_id(user_id) != host:
- # This can only happen if the host name has something funky in it
- raise Exception("Invalid host name")
-
- defer.returnValue(True)
-
- def get_joined_hosts(self, room_id, state_entry):
- state_group = state_entry.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- return self._get_joined_hosts(
- room_id, state_group, state_entry.state, state_entry=state_entry
- )
-
- @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
- # @defer.inlineCallbacks
- def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
-        # We don't use `state_group`, it's there so that we can cache based
-        # on it. However, it's important that it's never None, since two current_states
-        # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
-
- cache = self._get_joined_hosts_cache(room_id)
- joined_hosts = yield cache.get_destinations(state_entry)
-
- defer.returnValue(joined_hosts)
-
- @cached(max_entries=10000)
- def _get_joined_hosts_cache(self, room_id):
- return _JoinedHostsCache(self, room_id)
-
- @cachedInlineCallbacks(num_args=2)
- def did_forget(self, user_id, room_id):
- """Returns whether user_id has elected to discard history for room_id.
-
- Returns False if they have since re-joined."""
-
- def f(txn):
- sql = (
- "SELECT"
- " COUNT(*)"
- " FROM"
- " room_memberships"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- " AND"
- " forgotten = 0"
- )
- txn.execute(sql, (user_id, room_id))
- rows = txn.fetchall()
- return rows[0][0]
-
- count = yield self.runInteraction("did_forget_membership", f)
- defer.returnValue(count == 0)
-
-
-class RoomMemberStore(RoomMemberWorkerStore):
- def __init__(self, db_conn, hs):
- super(RoomMemberStore, self).__init__(db_conn, hs)
- self.register_background_update_handler(
- _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
- )
-
- def _store_room_members_txn(self, txn, events, backfilled):
-        """Store room membership events in the database.
- """
- self._simple_insert_many_txn(
- txn,
- table="room_memberships",
- values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
- }
- for event in events
- ],
- )
-
- for event in events:
- txn.call_after(
- self._membership_stream_cache.entity_has_changed,
- event.state_key,
- event.internal_metadata.stream_ordering,
- )
- txn.call_after(
- self.get_invited_rooms_for_user.invalidate, (event.state_key,)
- )
-
- # We update the local_invites table only if the event is "current",
-        # i.e., it's something that has just happened. If the event is an
-        # outlier it is only current if it's an "out of band membership",
- # like a remote invite or a rejection of a remote invite.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_out_of_band_membership()
- )
- is_mine = self.hs.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self._simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- },
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(
- sql,
- (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ),
- )
-
- @defer.inlineCallbacks
- def locally_reject_invite(self, user_id, room_id):
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- def f(txn, stream_ordering):
- txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
- with self._stream_id_gen.get_next() as stream_ordering:
- yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
- def forget(self, user_id, room_id):
- """Indicate that user_id wishes to discard history for room_id."""
-
- def f(txn):
- sql = (
- "UPDATE"
- " room_memberships"
- " SET"
- " forgotten = 1"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- )
- txn.execute(sql, (user_id, room_id))
-
- self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
-
- return self.runInteraction("forget_membership", f)
-
- @defer.inlineCallbacks
- def _background_add_membership_profile(self, progress, batch_size):
- target_min_stream_id = progress.get(
- "target_min_stream_id_inclusive", self._min_stream_order_on_start
- )
- max_stream_id = progress.get(
- "max_stream_id_exclusive", self._stream_order_on_start + 1
- )
-
- INSERT_CLUMP_SIZE = 1000
-
- def add_membership_profile_txn(txn):
- sql = """
- SELECT stream_ordering, event_id, events.room_id, event_json.json
- FROM events
- INNER JOIN event_json USING (event_id)
- INNER JOIN room_memberships USING (event_id)
- WHERE ? <= stream_ordering AND stream_ordering < ?
- AND type = 'm.room.member'
- ORDER BY stream_ordering DESC
- LIMIT ?
- """
-
- txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
-
- rows = self.cursor_to_dict(txn)
- if not rows:
- return 0
-
- min_stream_id = rows[-1]["stream_ordering"]
-
- to_update = []
- for row in rows:
- event_id = row["event_id"]
- room_id = row["room_id"]
- try:
- event_json = json.loads(row["json"])
- content = event_json['content']
- except Exception:
- continue
-
- display_name = content.get("displayname", None)
- avatar_url = content.get("avatar_url", None)
-
- if display_name or avatar_url:
- to_update.append((display_name, avatar_url, event_id, room_id))
-
- to_update_sql = """
- UPDATE room_memberships SET display_name = ?, avatar_url = ?
- WHERE event_id = ? AND room_id = ?
- """
- for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
- clump = to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(to_update_sql, clump)
-
- progress = {
- "target_min_stream_id_inclusive": target_min_stream_id,
- "max_stream_id_exclusive": min_stream_id,
- }
-
- self._background_update_progress_txn(
- txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
- )
-
- return len(rows)
-
- result = yield self.runInteraction(
- _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
- )
-
- if not result:
- yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
-
- defer.returnValue(result)
-
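
# A sketch of the clumped-update pattern above, using in-memory sqlite3 purely
# for illustration: rows are updated in fixed-size batches via executemany
# rather than one statement per row.
import sqlite3

INSERT_CLUMP_SIZE = 1000

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE room_memberships"
    " (event_id TEXT, room_id TEXT, display_name TEXT, avatar_url TEXT)"
)
conn.executemany(
    "INSERT INTO room_memberships (event_id, room_id) VALUES (?, ?)",
    [("$e%d" % i, "!room:hs") for i in range(2500)],
)

to_update = [("Alice %d" % i, None, "$e%d" % i, "!room:hs") for i in range(2500)]
sql = (
    "UPDATE room_memberships SET display_name = ?, avatar_url = ?"
    " WHERE event_id = ? AND room_id = ?"
)
for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
    conn.executemany(sql, to_update[index : index + INSERT_CLUMP_SIZE])

print(conn.execute(
    "SELECT count(*) FROM room_memberships WHERE display_name IS NOT NULL"
).fetchone())  # (2500,)
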
-
-class _JoinedHostsCache(object):
- """Cache for joined hosts in a room that is optimised to handle updates
- via state deltas.
- """
-
- def __init__(self, store, room_id):
- self.store = store
- self.room_id = room_id
-
- self.hosts_to_joined_users = {}
-
- self.state_group = object()
-
- self.linearizer = Linearizer("_JoinedHostsCache")
-
- self._len = 0
-
- @defer.inlineCallbacks
- def get_destinations(self, state_entry):
- """Get set of destinations for a state entry
-
- Args:
- state_entry(synapse.state._StateCacheEntry)
- """
- if state_entry.state_group == self.state_group:
- defer.returnValue(frozenset(self.hosts_to_joined_users))
-
- with (yield self.linearizer.queue(())):
- if state_entry.state_group == self.state_group:
- pass
- elif state_entry.prev_group == self.state_group:
- for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
- if typ != EventTypes.Member:
- continue
-
- host = intern_string(get_domain_from_id(state_key))
- user_id = state_key
- known_joins = self.hosts_to_joined_users.setdefault(host, set())
-
- event = yield self.store.get_event(event_id)
- if event.membership == Membership.JOIN:
- known_joins.add(user_id)
- else:
- known_joins.discard(user_id)
-
- if not known_joins:
- self.hosts_to_joined_users.pop(host, None)
- else:
- joined_users = yield self.store.get_joined_users_from_state(
- self.room_id, state_entry
- )
-
- self.hosts_to_joined_users = {}
- for user_id in joined_users:
- host = intern_string(get_domain_from_id(user_id))
- self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
-
- if state_entry.state_group:
- self.state_group = state_entry.state_group
- else:
- self.state_group = object()
- self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
- defer.returnValue(frozenset(self.hosts_to_joined_users))
-
- def __len__(self):
- return self._len
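
# A sketch of the host -> joined-users bookkeeping the cache above maintains:
# joins add a user under its host, leaves remove it, and a host disappears
# once its last user has left. Inputs are simplified (user_id, membership)
# pairs rather than real state deltas.
def update_hosts(hosts_to_joined_users, membership_changes):
    for user_id, membership in membership_changes:
        host = user_id.split(":", 1)[1]
        known_joins = hosts_to_joined_users.setdefault(host, set())
        if membership == "join":
            known_joins.add(user_id)
        else:
            known_joins.discard(user_id)
        if not known_joins:
            hosts_to_joined_users.pop(host, None)
    return hosts_to_joined_users

hosts = {}
update_hosts(hosts, [("@a:one.org", "join"), ("@b:two.org", "join")])
update_hosts(hosts, [("@b:two.org", "leave")])
print(sorted(hosts))  # ['one.org']
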
diff --git a/synapse/storage/schema/delta/35/00background_updates_add_col.sql b/synapse/storage/schema/delta/35/00background_updates_add_col.sql
new file mode 100644
index 0000000000..c2d2a4f836
--- /dev/null
+++ b/synapse/storage/schema/delta/35/00background_updates_add_col.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE background_updates ADD COLUMN depends_on TEXT;
diff --git a/synapse/storage/schema/full_schemas/54/full.sql b/synapse/storage/schema/full_schemas/54/full.sql
new file mode 100644
index 0000000000..1005880466
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/54/full.sql
@@ -0,0 +1,8 @@
+
+
+CREATE TABLE background_updates (
+ update_name text NOT NULL,
+ progress_json text NOT NULL,
+ depends_on text,
+ CONSTRAINT background_updates_uniqueness UNIQUE (update_name)
+);
diff --git a/synapse/storage/schema/full_schemas/README.txt b/synapse/storage/schema/full_schemas/README.txt
deleted file mode 100644
index d3f6401344..0000000000
--- a/synapse/storage/schema/full_schemas/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Building full schema dumps
-==========================
-
-These schemas need to be made from a database that has had all background updates run.
-
-Postgres
---------
-
-$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
-
-SQLite
-------
-
-$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
-
-After
------
-
-Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".
\ No newline at end of file
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0bfe1b4550..c522c80922 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,43 +14,21 @@
# limitations under the License.
import logging
-from collections import namedtuple
+from typing import Iterable, List, TypeVar
from six import iteritems, itervalues
-from six.moves import range
import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
-from synapse.api.errors import NotFoundError
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.storage.engines import PostgresEngine
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.util.caches import get_cache_factor_for, intern_string
-from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.dictionary_cache import DictionaryCache
-from synapse.util.stringutils import to_ascii
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
-
-MAX_STATE_DELTA_HOPS = 100
-
-
-class _GetStateGroupDelta(
- namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
- """Return type of get_state_group_delta that implements __len__, which lets
-    us use the iterable flag when caching
- """
-
- __slots__ = []
-
- def __len__(self):
- return len(self.delta_ids) if self.delta_ids else 0
+# Used for generic functions below
+T = TypeVar("T")
@attr.s(slots=True)
@@ -260,14 +238,14 @@ class StateFilter(object):
return len(self.concrete_types())
- def filter_state(self, state_dict):
+ def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
"""Returns the state filtered with by this StateFilter
Args:
- state (dict[tuple[str, str], Any]): The state map to filter
+ state: The state map to filter
Returns:
- dict[tuple[str, str], Any]: The filtered state map
+ The filtered state map
"""
if self.is_full():
return dict(state_dict)
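
# A sketch of the typing pattern this hunk introduces: StateMap[T] (imported
# from synapse.types) is a mapping from (event type, state key) pairs to some
# value type T, so a filter written once preserves whatever the values are,
# whether event IDs or full events. Reproduced below with local aliases.
from typing import Dict, Tuple, TypeVar

T = TypeVar("T")
StateMap = Dict[Tuple[str, str], T]

def keep_members_only(state: StateMap[T]) -> StateMap[T]:
    return {key: value for key, value in state.items() if key[0] == "m.room.member"}

ids: StateMap[str] = {
    ("m.room.member", "@a:hs"): "$member_event",
    ("m.room.create", ""): "$create_event",
}
print(keep_members_only(ids))  # {('m.room.member', '@a:hs'): '$member_event'}
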
@@ -353,248 +331,23 @@ class StateFilter(object):
return member_filter, non_member_filter
-# this inherits from EventsWorkerStore because it calls self.get_events
-class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
- """The parts of StateGroupStore that can be called from workers.
+class StateGroupStorage(object):
+    """High level interface to fetching state for events.
"""
- STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
- STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
- CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
- def __init__(self, db_conn, hs):
- super(StateGroupWorkerStore, self).__init__(db_conn, hs)
-
- # Originally the state store used a single DictionaryCache to cache the
- # event IDs for the state types in a given state group to avoid hammering
- # on the state_group* tables.
- #
- # The point of using a DictionaryCache is that it can cache a subset
- # of the state events for a given state group (i.e. a subset of the keys for a
- # given dict which is an entry in the cache for a given state group ID).
- #
- # However, this poses problems when performing complicated queries
- # on the store - for instance: "give me all the state for this group, but
- # limit members to this subset of users", as DictionaryCache's API isn't
- # rich enough to say "please cache any of these fields, apart from this subset".
- # This is problematic when lazy loading members, which requires this behaviour,
- # as without it the cache has no choice but to speculatively load all
- # state events for the group, which negates the efficiency being sought.
- #
- # Rather than overcomplicating DictionaryCache's API, we instead split the
- # state_group_cache into two halves - one for tracking non-member events,
- # and the other for tracking member_events. This means that lazy loading
- # queries can be made in a cache-friendly manner by querying both caches
- # separately and then merging the result. So for the example above, you
- # would query the members cache for a specific subset of state keys
- # (which DictionaryCache will handle efficiently and fine) and the non-members
- # cache for all state (which DictionaryCache will similarly handle fine)
- # and then just merge the results together.
- #
- # We size the non-members cache to be smaller than the members cache as the
- # vast majority of state in Matrix (today) is member events.
-
- self._state_group_cache = DictionaryCache(
- "*stateGroupCache*",
- # TODO: this hasn't been tuned yet
- 50000 * get_cache_factor_for("stateGroupCache"),
- )
- self._state_group_members_cache = DictionaryCache(
- "*stateGroupMembersCache*",
- 500000 * get_cache_factor_for("stateGroupMembersCache"),
- )
-
- @defer.inlineCallbacks
- def get_room_version(self, room_id):
- """Get the room_version of a given room
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[str]
-
- Raises:
- NotFoundError if the room is unknown
- """
- # for now we do this by looking at the create event. We may want to cache this
- # more intelligently in future.
-
- # Retrieve the room's create event
- create_event = yield self.get_create_event_for_room(room_id)
- defer.returnValue(create_event.content.get("room_version", "1"))
-
- @defer.inlineCallbacks
- def get_room_predecessor(self, room_id):
- """Get the predecessor room of an upgraded room if one exists.
- Otherwise return None.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[unicode|None]: predecessor room id
-
- Raises:
- NotFoundError if the room is unknown
- """
- # Retrieve the room's create event
- create_event = yield self.get_create_event_for_room(room_id)
-
- # Return predecessor if present
- defer.returnValue(create_event.content.get("predecessor", None))
-
- @defer.inlineCallbacks
- def get_create_event_for_room(self, room_id):
- """Get the create state event for a room.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[EventBase]: The room creation event.
-
- Raises:
- NotFoundError if the room is unknown
- """
- state_ids = yield self.get_current_state_ids(room_id)
- create_id = state_ids.get((EventTypes.Create, ""))
+ def __init__(self, hs, stores):
+ self.stores = stores
- # If we can't find the create event, assume we've hit a dead end
- if not create_id:
- raise NotFoundError("Unknown room %s" % (room_id))
-
- # Retrieve the room's create event and return
- create_event = yield self.get_event(create_id)
- defer.returnValue(create_event)
-
- @cached(max_entries=100000, iterable=True)
- def get_current_state_ids(self, room_id):
- """Get the current state event ids for a room based on the
- current_state_events table.
-
- Args:
- room_id (str)
-
- Returns:
- deferred: dict of (type, state_key) -> event_id
- """
-
- def _get_current_state_ids_txn(txn):
- txn.execute(
- """SELECT type, state_key, event_id FROM current_state_events
- WHERE room_id = ?
- """,
- (room_id,),
- )
-
- return {
- (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
- }
-
- return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
-
- # FIXME: how should this be cached?
- def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
- """Get the current state event of a given type for a room based on the
- current_state_events table. This may not be as up-to-date as the result
- of doing a fresh state resolution as per state_handler.get_current_state
-
- Args:
- room_id (str)
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
- event ID.
- """
-
- def _get_filtered_current_state_ids_txn(txn):
- results = {}
- sql = """
- SELECT type, state_key, event_id FROM current_state_events
- WHERE room_id = ?
- """
-
- where_clause, where_args = state_filter.make_sql_filter_clause()
-
- if where_clause:
- sql += " AND (%s)" % (where_clause,)
-
- args = [room_id]
- args.extend(where_args)
- txn.execute(sql, args)
- for row in txn:
- typ, state_key, event_id = row
- key = (intern_string(typ), intern_string(state_key))
- results[key] = event_id
-
- return results
-
- return self.runInteraction(
- "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
- )
-
- @defer.inlineCallbacks
- def get_canonical_alias_for_room(self, room_id):
- """Get canonical alias for room, if any
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[str|None]: The canonical alias, if any
- """
-
- state = yield self.get_filtered_current_state_ids(
- room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
- )
-
- event_id = state.get((EventTypes.CanonicalAlias, ""))
- if not event_id:
- return
-
- event = yield self.get_event(event_id, allow_none=True)
- if not event:
- return
-
- defer.returnValue(event.content.get("canonical_alias"))
-
- @cached(max_entries=10000, iterable=True)
- def get_state_group_delta(self, state_group):
+ def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
- (prev_group, delta_ids), where both may be None.
+ Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
+ (prev_group, delta_ids)
"""
- def _get_state_group_delta_txn(txn):
- prev_group = self._simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": state_group},
- retcol="prev_state_group",
- allow_none=True,
- )
-
- if not prev_group:
- return _GetStateGroupDelta(None, None)
-
- delta_ids = self._simple_select_list_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- retcols=("type", "state_key", "event_id"),
- )
-
- return _GetStateGroupDelta(
- prev_group,
- {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
- )
-
- return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
+ return self.stores.state.get_state_group_delta(state_group)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
@@ -605,18 +358,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
event_ids (iterable[str]): ids of the events
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
+ Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
- defer.returnValue({})
+ return {}
- event_to_groups = yield self._get_state_group_for_events(event_ids)
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups)
+ group_to_state = yield self.stores.state._get_state_for_groups(groups)
- defer.returnValue(group_to_state)
+ return group_to_state
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
@@ -630,22 +383,21 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
group_to_state = yield self._get_state_for_groups((state_group,))
- defer.returnValue(group_to_state[state_group])
+ return group_to_state[state_group]
@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
-
Returns:
Deferred[dict[int, list[EventBase]]]:
dict of state_group_id -> list of state events.
"""
if not event_ids:
- defer.returnValue({})
+ return {}
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
- state_event_map = yield self.get_events(
+ state_event_map = yield self.stores.main.get_events(
[
ev_id
for group_ids in itervalues(group_to_ids)
@@ -654,164 +406,50 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
get_prev_content=False,
)
- defer.returnValue(
- {
- group: [
- state_event_map[v]
- for v in itervalues(event_id_map)
- if v in state_event_map
- ]
- for group, event_id_map in iteritems(group_to_ids)
- }
- )
+ return {
+ group: [
+ state_event_map[v]
+ for v in itervalues(event_id_map)
+ if v in state_event_map
+ ]
+ for group, event_id_map in iteritems(group_to_ids)
+ }
- @defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, state_filter):
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
- groups(list[int]): list of state group IDs to query
- state_filter (StateFilter): The state filter used to fetch state
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
- results = {}
-
- chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
- for chunk in chunks:
- res = yield self.runInteraction(
- "_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn,
- chunk,
- state_filter,
- )
- results.update(res)
- defer.returnValue(results)
-
- def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter=StateFilter.all()
- ):
- results = {group: {} for group in groups}
-
- where_clause, where_args = state_filter.make_sql_filter_clause()
-
- # Unless the filter clause is empty, we're going to append it after an
- # existing where clause
- if where_clause:
- where_clause = " AND (%s)" % (where_clause,)
-
- if isinstance(self.database_engine, PostgresEngine):
- # Temporarily disable sequential scans in this transaction. This is
- # a temporary hack until we can add the right indices in
- txn.execute("SET LOCAL enable_seqscan=off")
-
- # The below query walks the state_group tree so that the "state"
- # table includes all state_groups in the tree. It then joins
- # against `state_groups_state` to fetch the latest state.
- # It assumes that previous state groups are always numerically
- # lesser.
- # The PARTITION is used to get the event_id in the greatest state
- # group for the given type, state_key.
- # This may return multiple rows per (type, state_key), but last_value
- # should be the same.
- sql = """
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT DISTINCT type, state_key, last_value(event_id) OVER (
- PARTITION BY type, state_key ORDER BY state_group ASC
- ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
- ) AS event_id FROM state_groups_state
- WHERE state_group IN (
- SELECT state_group FROM state
- )
- """
-
- for group in groups:
- args = [group]
- args.extend(where_args)
-
- txn.execute(sql + where_clause, args)
- for row in txn:
- typ, state_key, event_id = row
- key = (typ, state_key)
- results[group][key] = event_id
- else:
- max_entries_returned = state_filter.max_entries_returned()
-
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- for group in groups:
- next_group = group
-
- while next_group:
- # We did this before by getting the list of group ids, and
- # then passing that list to sqlite to get latest event for
- # each (type, state_key). However, that was terribly slow
- # without the right indices (which we can't add until
- # after we finish deduping state, which requires this func)
- args = [next_group]
- args.extend(where_args)
-
- txn.execute(
- "SELECT type, state_key, event_id FROM state_groups_state"
- " WHERE state_group = ? " + where_clause,
- args,
- )
- results[group].update(
- ((typ, state_key), event_id)
- for typ, state_key, event_id in txn
- if (typ, state_key) not in results[group]
- )
-
- # If the number of entries in the (type,state_key)->event_id dict
- # matches the number of (type,state_keys) types we were searching
- # for, then we must have found them all, so no need to go walk
- # further down the tree... UNLESS our types filter contained
- # wildcards (i.e. Nones) in which case we have to do an exhaustive
- # search
- if (
- max_entries_returned is not None
- and len(results[group]) == max_entries_returned
- ):
- break
-
- next_group = self._simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
-
- return results
+ return self.stores.state._get_state_groups_from_groups(groups, state_filter)
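
# A sketch of the recursive walk the removed transaction performed: follow
# prev_state_group edges from a group back to its root, then keep, for each
# (type, state_key), the entry from the numerically greatest group on that
# path. Modern sqlite3 supports WITH RECURSIVE, so this runs standalone
# against a heavily trimmed stand-in schema.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE state_group_edges (state_group BIGINT, prev_state_group BIGINT);
    CREATE TABLE state_groups_state (
        state_group BIGINT, type TEXT, state_key TEXT, event_id TEXT
    );
    INSERT INTO state_group_edges VALUES (2, 1), (3, 2);
    INSERT INTO state_groups_state VALUES
        (1, 'm.room.create', '', '$create'),
        (1, 'm.room.member', '@a:hs', '$join_a_v1'),
        (3, 'm.room.member', '@a:hs', '$join_a_v2');
    """
)

sql = """
    WITH RECURSIVE state(state_group) AS (
        VALUES(?)
        UNION ALL
        SELECT prev_state_group FROM state_group_edges e, state s
        WHERE s.state_group = e.state_group
    )
    SELECT type, state_key, event_id FROM state_groups_state
    WHERE state_group IN (SELECT state_group FROM state)
    ORDER BY state_group ASC
"""

result = {}
for typ, state_key, event_id in conn.execute(sql, (3,)):
    # Rows arrive in ascending state_group order, so later (greater) groups
    # overwrite earlier ones, mirroring the PARTITION BY / last_value trick.
    result[(typ, state_key)] = event_id
print(result)
# {('m.room.create', ''): '$create', ('m.room.member', '@a:hs'): '$join_a_v2'}
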
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
-
Args:
event_ids (list[string])
state_filter (StateFilter): The state filter used to fetch state
from the database.
-
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
- event_to_groups = yield self._get_state_group_for_events(event_ids)
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, state_filter)
+ group_to_state = yield self.stores.state._get_state_for_groups(
+ groups, state_filter
+ )
- state_event_map = yield self.get_events(
+ state_event_map = yield self.stores.main.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
get_prev_content=False,
)
@@ -825,7 +463,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
for event_id, group in iteritems(event_to_groups)
}
- defer.returnValue({event: event_to_state[event] for event in event_ids})
+ return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
@@ -841,17 +479,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
"""
- event_to_groups = yield self._get_state_group_for_events(event_ids)
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, state_filter)
+ group_to_state = yield self.stores.state._get_state_for_groups(
+ groups, state_filter
+ )
event_to_state = {
event_id: group_to_state[group]
for event_id, group in iteritems(event_to_groups)
}
- defer.returnValue({event: event_to_state[event] for event in event_ids})
+ return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -867,7 +507,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], state_filter)
- defer.returnValue(state_map[event_id])
+ return state_map[event_id]
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -883,79 +523,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
- defer.returnValue(state_map[event_id])
-
- @cached(max_entries=50000)
- def _get_state_group_for_event(self, event_id):
- return self._simple_select_one_onecol(
- table="event_to_state_groups",
- keyvalues={"event_id": event_id},
- retcol="state_group",
- allow_none=True,
- desc="_get_state_group_for_event",
- )
-
- @cachedList(
- cached_method_name="_get_state_group_for_event",
- list_name="event_ids",
- num_args=1,
- inlineCallbacks=True,
- )
- def _get_state_group_for_events(self, event_ids):
- """Returns mapping event_id -> state_group
- """
- rows = yield self._simple_select_many_batch(
- table="event_to_state_groups",
- column="event_id",
- iterable=event_ids,
- keyvalues={},
- retcols=("event_id", "state_group"),
- desc="_get_state_group_for_events",
- )
-
- defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
-
- def _get_state_for_group_using_cache(self, cache, group, state_filter):
- """Checks if group is in cache. See `_get_state_for_groups`
-
- Args:
- cache(DictionaryCache): the state group cache to use
- group(int): The state group to lookup
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ return state_map[event_id]
- Returns 2-tuple (`state_dict`, `got_all`).
-        `got_all` is a bool indicating if we successfully retrieved all
-        requested state from the cache; if False we need to query the DB for the
-        missing state.
- """
- is_all, known_absent, state_dict_ids = cache.get(group)
-
- if is_all or state_filter.is_full():
- # Either we have everything or want everything, either way
- # `is_all` tells us whether we've gotten everything.
- return state_filter.filter_state(state_dict_ids), is_all
-
- # tracks whether any of our requested types are missing from the cache
- missing_types = False
-
- if state_filter.has_wildcards():
- # We don't know if we fetched all the state keys for the types in
- # the filter that are wildcards, so we have to assume that we may
- # have missed some.
- missing_types = True
- else:
- # There aren't any wild cards, so `concrete_types()` returns the
- # complete list of event types we're wanting.
- for key in state_filter.concrete_types():
- if key not in state_dict_ids and key not in known_absent:
- missing_types = True
- break
-
- return state_filter.filter_state(state_dict_ids), not missing_types
-
- @defer.inlineCallbacks
- def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ def _get_state_for_groups(
+ self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ ):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@@ -965,157 +537,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
"""
-
- member_filter, non_member_filter = state_filter.get_member_split()
-
- # Now we look them up in the member and non-member caches
- non_member_state, incomplete_groups_nm, = (
- yield self._get_state_for_groups_using_cache(
- groups, self._state_group_cache, state_filter=non_member_filter
- )
- )
-
- member_state, incomplete_groups_m, = (
- yield self._get_state_for_groups_using_cache(
- groups, self._state_group_members_cache, state_filter=member_filter
- )
- )
-
- state = dict(non_member_state)
- for group in groups:
- state[group].update(member_state[group])
-
- # Now fetch any missing groups from the database
-
- incomplete_groups = incomplete_groups_m | incomplete_groups_nm
-
- if not incomplete_groups:
- defer.returnValue(state)
-
- cache_sequence_nm = self._state_group_cache.sequence
- cache_sequence_m = self._state_group_members_cache.sequence
-
- # Help the cache hit ratio by expanding the filter a bit
- db_state_filter = state_filter.return_expanded()
-
- group_to_state_dict = yield self._get_state_groups_from_groups(
- list(incomplete_groups), state_filter=db_state_filter
- )
-
- # Now lets update the caches
- self._insert_into_cache(
- group_to_state_dict,
- db_state_filter,
- cache_seq_num_members=cache_sequence_m,
- cache_seq_num_non_members=cache_sequence_nm,
- )
-
- # And finally update the result dict, by filtering out any extra
- # stuff we pulled out of the database.
- for group, group_state_dict in iteritems(group_to_state_dict):
- # We just replace any existing entries, as we will have loaded
- # everything we need from the database anyway.
- state[group] = state_filter.filter_state(group_state_dict)
-
- defer.returnValue(state)
-
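
# A sketch of the split-cache read path above, with plain dicts standing in
# for the member cache, the non-member cache and the database: answer what we
# can from the two caches, fetch only the incomplete groups from the DB, and
# use that both to refill the caches and to overwrite the partial answers.
def get_state_for_groups(groups, member_cache, non_member_cache, db):
    state = {}
    incomplete = set()
    for group in groups:
        combined = {}
        for cache in (non_member_cache, member_cache):
            entry = cache.get(group)
            if entry is None:
                incomplete.add(group)
            else:
                combined.update(entry)
        state[group] = combined

    for group in incomplete:
        full = db[group]
        member_cache[group] = {
            k: v for k, v in full.items() if k[0] == "m.room.member"
        }
        non_member_cache[group] = {
            k: v for k, v in full.items() if k[0] != "m.room.member"
        }
        state[group] = dict(full)
    return state

db = {1: {("m.room.create", ""): "$c", ("m.room.member", "@a:hs"): "$m"}}
member_cache, non_member_cache = {}, {}
print(get_state_for_groups([1], member_cache, non_member_cache, db))
print(member_cache)  # pre-filled, so the next call is a pure cache hit
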
- def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
- """Gets the state at each of a list of state groups, optionally
- filtering by type/state_key, querying from a specific cache.
-
- Args:
- groups (iterable[int]): list of state groups for which we want
- to get the state.
- cache (DictionaryCache): the cache of group ids to state dicts which
- we will pass through - either the normal state cache or the specific
- members state cache.
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- of entries in the cache, and the state group ids either missing
- from the cache or incomplete.
- """
- results = {}
- incomplete_groups = set()
- for group in set(groups):
- state_dict_ids, got_all = self._get_state_for_group_using_cache(
- cache, group, state_filter
- )
- results[group] = state_dict_ids
-
- if not got_all:
- incomplete_groups.add(group)
-
- return results, incomplete_groups
-
- def _insert_into_cache(
- self,
- group_to_state_dict,
- state_filter,
- cache_seq_num_members,
- cache_seq_num_non_members,
- ):
- """Inserts results from querying the database into the relevant cache.
-
- Args:
- group_to_state_dict (dict): The new entries pulled from database.
- Map from state group to state dict
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
- cache_seq_num_members (int): Sequence number of member cache since
- last lookup in cache
- cache_seq_num_non_members (int): Sequence number of member cache since
- last lookup in cache
- """
-
- # We need to work out which types we've fetched from the DB for the
- # member vs non-member caches. This should be as accurate as possible,
- # but can be an underestimate (e.g. when we have wild cards)
-
- member_filter, non_member_filter = state_filter.get_member_split()
- if member_filter.is_full():
- # We fetched all member events
- member_types = None
- else:
- # `concrete_types()` will only return a subset when there are wild
- # cards in the filter, but that's fine.
- member_types = member_filter.concrete_types()
-
- if non_member_filter.is_full():
- # We fetched all non member events
- non_member_types = None
- else:
- non_member_types = non_member_filter.concrete_types()
-
- for group, group_state_dict in iteritems(group_to_state_dict):
- state_dict_members = {}
- state_dict_non_members = {}
-
- for k, v in iteritems(group_state_dict):
- if k[0] == EventTypes.Member:
- state_dict_members[k] = v
- else:
- state_dict_non_members[k] = v
-
- self._state_group_members_cache.update(
- cache_seq_num_members,
- key=group,
- value=state_dict_members,
- fetched_keys=member_types,
- )
-
- self._state_group_cache.update(
- cache_seq_num_non_members,
- key=group,
- value=state_dict_non_members,
- fetched_keys=non_member_types,
- )
+ return self.stores.state._get_state_for_groups(groups, state_filter)
def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
@@ -1135,393 +559,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Deferred[int]: The state group ID
"""
-
- def _store_state_group_txn(txn):
- if current_state_ids is None:
- # AFAIK, this can never happen
- raise Exception("current_state_ids cannot be None")
-
- state_group = self.database_engine.get_next_state_group_id(txn)
-
- self._simple_insert_txn(
- txn,
- table="state_groups",
- values={"id": state_group, "room_id": room_id, "event_id": event_id},
- )
-
- # We persist as a delta if we can, while also ensuring the chain
-            # of deltas isn't too long, as otherwise read performance degrades.
- if prev_group:
- is_in_db = self._simple_select_one_onecol_txn(
- txn,
- table="state_groups",
- keyvalues={"id": prev_group},
- retcol="id",
- allow_none=True,
- )
- if not is_in_db:
- raise Exception(
- "Trying to persist state with unpersisted prev_group: %r"
- % (prev_group,)
- )
-
- potential_hops = self._count_state_group_hops_txn(txn, prev_group)
- if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- self._simple_insert_txn(
- txn,
- table="state_group_edges",
- values={"state_group": state_group, "prev_state_group": prev_group},
- )
-
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(delta_ids)
- ],
- )
- else:
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(current_state_ids)
- ],
- )
-
- # Prefill the state group caches with this group.
- # It's fine to use the sequence like this as the state group map
- # is immutable. (If the map wasn't immutable then this prefill could
- # race with another update)
-
- current_member_state_ids = {
- s: ev
- for (s, ev) in iteritems(current_state_ids)
- if s[0] == EventTypes.Member
- }
- txn.call_after(
- self._state_group_members_cache.update,
- self._state_group_members_cache.sequence,
- key=state_group,
- value=dict(current_member_state_ids),
- )
-
- current_non_member_state_ids = {
- s: ev
- for (s, ev) in iteritems(current_state_ids)
- if s[0] != EventTypes.Member
- }
- txn.call_after(
- self._state_group_cache.update,
- self._state_group_cache.sequence,
- key=state_group,
- value=dict(current_non_member_state_ids),
- )
-
- return state_group
-
- return self.runInteraction("store_state_group", _store_state_group_txn)
-
- def _count_state_group_hops_txn(self, txn, state_group):
- """Given a state group, count how many hops there are in the tree.
-
- This is used to ensure the delta chains don't get too long.
- """
- if isinstance(self.database_engine, PostgresEngine):
- sql = """
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT count(*) FROM state;
- """
-
- txn.execute(sql, (state_group,))
- row = txn.fetchone()
- if row and row[0]:
- return row[0]
- else:
- return 0
- else:
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
- count = 0
-
- while next_group:
- next_group = self._simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
- if next_group:
- count += 1
-
- return count
-
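
# A sketch of the sqlite fallback above: with no WITH RECURSIVE available,
# count delta-chain hops by repeatedly looking up prev_state_group until the
# chain ends. Edges are modelled as a dict of state_group -> prev_state_group.
def count_state_group_hops(edges, state_group):
    count = 0
    next_group = edges.get(state_group)
    while next_group is not None:
        count += 1
        next_group = edges.get(next_group)
    return count

print(count_state_group_hops({4: 3, 3: 2, 2: 1}, 4))  # 3
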
-
-class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
- """ Keeps track of the state at a given event.
-
-    This is done by the concept of `state groups`. Every event is assigned
- a state group (identified by an arbitrary string), which references a
- collection of state events. The current state of an event is then the
- collection of state events referenced by the event's state group.
-
- Hence, every change in the current state causes a new state group to be
-    generated. However, if no change happens (e.g., if we get a message event
-    with only one parent), it inherits the state group from its parent.
-
- There are three tables:
-    * `state_groups`: Stores group name, first event in the group and
- room id.
- * `event_to_state_groups`: Maps events to state groups.
- * `state_groups_state`: Maps state group to state events.
- """
-
- STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
- STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
- CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
- EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
-
- def __init__(self, db_conn, hs):
- super(StateStore, self).__init__(db_conn, hs)
- self.register_background_update_handler(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
- self._background_deduplicate_state,
+ return self.stores.state.store_state_group(
+ event_id, room_id, prev_group, delta_ids, current_state_ids
)
- self.register_background_update_handler(
- self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
- )
- self.register_background_index_update(
- self.CURRENT_STATE_INDEX_UPDATE_NAME,
- index_name="current_state_events_member_index",
- table="current_state_events",
- columns=["state_key"],
- where_clause="type='m.room.member'",
- )
- self.register_background_index_update(
- self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
- index_name="event_to_state_groups_sg_index",
- table="event_to_state_groups",
- columns=["state_group"],
- )
-
- def _store_event_state_mappings_txn(self, txn, events_and_contexts):
- state_groups = {}
- for event, context in events_and_contexts:
- if event.internal_metadata.is_outlier():
- continue
-
- # if the event was rejected, just give it the same state as its
- # predecessor.
- if context.rejected:
- state_groups[event.event_id] = context.prev_group
- continue
-
- state_groups[event.event_id] = context.state_group
-
- self._simple_insert_many_txn(
- txn,
- table="event_to_state_groups",
- values=[
- {"state_group": state_group_id, "event_id": event_id}
- for event_id, state_group_id in iteritems(state_groups)
- ],
- )
-
- for event_id, state_group_id in iteritems(state_groups):
- txn.call_after(
- self._get_state_group_for_event.prefill, (event_id,), state_group_id
- )
-
- @defer.inlineCallbacks
- def _background_deduplicate_state(self, progress, batch_size):
-        """This background update will slowly deduplicate state by re-encoding
-        it as deltas.
- """
- last_state_group = progress.get("last_state_group", 0)
- rows_inserted = progress.get("rows_inserted", 0)
- max_group = progress.get("max_group", None)
-
- BATCH_SIZE_SCALE_FACTOR = 100
-
- batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
-
- if max_group is None:
- rows = yield self._execute(
- "_background_deduplicate_state",
- None,
- "SELECT coalesce(max(id), 0) FROM state_groups",
- )
- max_group = rows[0][0]
-
- def reindex_txn(txn):
- new_last_state_group = last_state_group
- for count in range(batch_size):
- txn.execute(
- "SELECT id, room_id FROM state_groups"
- " WHERE ? < id AND id <= ?"
- " ORDER BY id ASC"
- " LIMIT 1",
- (new_last_state_group, max_group),
- )
- row = txn.fetchone()
- if row:
- state_group, room_id = row
-
- if not row or not state_group:
- return True, count
-
- txn.execute(
- "SELECT state_group FROM state_group_edges"
- " WHERE state_group = ?",
- (state_group,),
- )
-
- # If we reach a point where we've already started inserting
- # edges we should stop.
- if txn.fetchall():
- return True, count
-
- txn.execute(
- "SELECT coalesce(max(id), 0) FROM state_groups"
- " WHERE id < ? AND room_id = ?",
- (state_group, room_id),
- )
- prev_group, = txn.fetchone()
- new_last_state_group = state_group
-
- if prev_group:
- potential_hops = self._count_state_group_hops_txn(txn, prev_group)
- if potential_hops >= MAX_STATE_DELTA_HOPS:
-                        # We want to ensure chains are at most this long,
- # otherwise read performance degrades.
- continue
-
- prev_state = self._get_state_groups_from_groups_txn(
- txn, [prev_group]
- )
- prev_state = prev_state[prev_group]
-
- curr_state = self._get_state_groups_from_groups_txn(
- txn, [state_group]
- )
- curr_state = curr_state[state_group]
-
- if not set(prev_state.keys()) - set(curr_state.keys()):
- # We can only do a delta if the current has a strict super set
- # of keys
-
- delta_state = {
- key: value
- for key, value in iteritems(curr_state)
- if prev_state.get(key, None) != value
- }
-
- self._simple_delete_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": state_group},
- )
-
- self._simple_insert_txn(
- txn,
- table="state_group_edges",
- values={
- "state_group": state_group,
- "prev_state_group": prev_group,
- },
- )
-
- self._simple_delete_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- )
-
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(delta_state)
- ],
- )
-
- progress = {
- "last_state_group": state_group,
- "rows_inserted": rows_inserted + batch_size,
- "max_group": max_group,
- }
-
- self._background_update_progress_txn(
- txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
- )
-
- return False, batch_size
-
- finished, result = yield self.runInteraction(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
- )
-
- if finished:
- yield self._end_background_update(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
- )
-
- defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
-
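
# A sketch of the core of the deduplication above: a group can be re-encoded
# as a delta against a previous group only when its keys are a superset of
# the previous group's, in which case only the changed entries are kept.
def compute_delta(prev_state, curr_state):
    """Return the delta dict, or None if curr_state cannot be expressed as a
    delta on top of prev_state (i.e. some previous key has gone away)."""
    if set(prev_state) - set(curr_state):
        return None
    return {
        key: value
        for key, value in curr_state.items()
        if prev_state.get(key) != value
    }

prev = {("m.room.create", ""): "$c", ("m.room.member", "@a:hs"): "$v1"}
curr = {
    ("m.room.create", ""): "$c",
    ("m.room.member", "@a:hs"): "$v2",
    ("m.room.member", "@b:hs"): "$v3",
}
print(compute_delta(prev, curr))
# {('m.room.member', '@a:hs'): '$v2', ('m.room.member', '@b:hs'): '$v3'}
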
- @defer.inlineCallbacks
- def _background_index_state(self, progress, batch_size):
- def reindex_txn(conn):
- conn.rollback()
- if isinstance(self.database_engine, PostgresEngine):
- # postgres insists on autocommit for the index
- conn.set_session(autocommit=True)
- try:
- txn = conn.cursor()
- txn.execute(
- "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
- " ON state_groups_state(state_group, type, state_key)"
- )
- txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
- finally:
- conn.set_session(autocommit=False)
- else:
- txn = conn.cursor()
- txn.execute(
- "CREATE INDEX state_groups_state_type_idx"
- " ON state_groups_state(state_group, type, state_key)"
- )
- txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
-
- yield self.runWithConnection(reindex_txn)
-
- yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
-
- defer.returnValue(1)
diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py
deleted file mode 100644
index ff266b09b0..0000000000
--- a/synapse/storage/stats.py
+++ /dev/null
@@ -1,468 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018, 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes, Membership
-from synapse.storage.prepare_database import get_statements
-from synapse.storage.state_deltas import StateDeltasStore
-from synapse.util.caches.descriptors import cached
-
-logger = logging.getLogger(__name__)
-
-# these fields track absolute totals (e.g. the total number of joined members in a room), rather than deltas
-ABSOLUTE_STATS_FIELDS = {
- "room": (
- "current_state_events",
- "joined_members",
- "invited_members",
- "left_members",
- "banned_members",
- "state_events",
- ),
- "user": ("public_rooms", "private_rooms"),
-}
-
-TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")}
-
-TEMP_TABLE = "_temp_populate_stats"
-
-
-class StatsStore(StateDeltasStore):
- def __init__(self, db_conn, hs):
- super(StatsStore, self).__init__(db_conn, hs)
-
- self.server_name = hs.hostname
- self.clock = self.hs.get_clock()
- self.stats_enabled = hs.config.stats_enabled
- self.stats_bucket_size = hs.config.stats_bucket_size
-
- self.register_background_update_handler(
- "populate_stats_createtables", self._populate_stats_createtables
- )
- self.register_background_update_handler(
- "populate_stats_process_rooms", self._populate_stats_process_rooms
- )
- self.register_background_update_handler(
- "populate_stats_cleanup", self._populate_stats_cleanup
- )
-
- @defer.inlineCallbacks
- def _populate_stats_createtables(self, progress, batch_size):
-
- if not self.stats_enabled:
- yield self._end_background_update("populate_stats_createtables")
- defer.returnValue(1)
-
- # Get all the rooms that we want to process.
- def _make_staging_area(txn):
- # Create the temporary tables
- stmts = get_statements("""
- -- We just recreate the table, we'll be reinserting the
- -- correct entries again later anyway.
- DROP TABLE IF EXISTS {temp}_rooms;
-
- CREATE TABLE IF NOT EXISTS {temp}_rooms(
- room_id TEXT NOT NULL,
- events BIGINT NOT NULL
- );
-
- CREATE INDEX {temp}_rooms_events
- ON {temp}_rooms(events);
- CREATE INDEX {temp}_rooms_id
- ON {temp}_rooms(room_id);
- """.format(temp=TEMP_TABLE).splitlines())
-
- for statement in stmts:
- txn.execute(statement)
-
- sql = (
- "CREATE TABLE IF NOT EXISTS "
- + TEMP_TABLE
- + "_position(position TEXT NOT NULL)"
- )
- txn.execute(sql)
-
- # Get rooms we want to process from the database, only adding
- # those that we haven't (i.e. those not in room_stats_earliest_token)
- sql = """
- INSERT INTO %s_rooms (room_id, events)
- SELECT c.room_id, count(*) FROM current_state_events AS c
- LEFT JOIN room_stats_earliest_token AS t USING (room_id)
- WHERE t.room_id IS NULL
- GROUP BY c.room_id
- """ % (TEMP_TABLE,)
- txn.execute(sql)
-
- new_pos = yield self.get_max_stream_id_in_current_state_deltas()
- yield self.runInteraction("populate_stats_temp_build", _make_staging_area)
- yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
- self.get_earliest_token_for_room_stats.invalidate_all()
-
- yield self._end_background_update("populate_stats_createtables")
- defer.returnValue(1)
-
- @defer.inlineCallbacks
- def _populate_stats_cleanup(self, progress, batch_size):
- """
-        Update the stats stream position, then clean up the temporary staging tables.
- """
- if not self.stats_enabled:
- yield self._end_background_update("populate_stats_cleanup")
- defer.returnValue(1)
-
- position = yield self._simple_select_one_onecol(
- TEMP_TABLE + "_position", None, "position"
- )
- yield self.update_stats_stream_pos(position)
-
- def _delete_staging_area(txn):
- txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
- txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
-
- yield self.runInteraction("populate_stats_cleanup", _delete_staging_area)
-
- yield self._end_background_update("populate_stats_cleanup")
- defer.returnValue(1)
-
- @defer.inlineCallbacks
- def _populate_stats_process_rooms(self, progress, batch_size):
-
- if not self.stats_enabled:
- yield self._end_background_update("populate_stats_process_rooms")
- defer.returnValue(1)
-
-        # If no progress has been recorded yet, delete everything and start from scratch.
- if not progress:
- yield self.delete_all_stats()
-
- def _get_next_batch(txn):
- # Only fetch 250 rooms, so we don't fetch too many at once, even
-            # if those 250 rooms have fewer than batch_size state events.
- sql = """
- SELECT room_id, events FROM %s_rooms
- ORDER BY events DESC
- LIMIT 250
- """ % (
- TEMP_TABLE,
- )
- txn.execute(sql)
- rooms_to_work_on = txn.fetchall()
-
- if not rooms_to_work_on:
- return None
-
- # Get how many are left to process, so we can give status on how
- # far we are in processing
- txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
- progress["remaining"] = txn.fetchone()[0]
-
- return rooms_to_work_on
-
- rooms_to_work_on = yield self.runInteraction(
- "populate_stats_temp_read", _get_next_batch
- )
-
-        # No more rooms left to process -- mark this background update as complete.
- if not rooms_to_work_on:
- yield self._end_background_update("populate_stats_process_rooms")
- defer.returnValue(1)
-
- logger.info(
- "Processing the next %d rooms of %d remaining",
- len(rooms_to_work_on), progress["remaining"],
- )
-
- # Number of state events we've processed by going through each room
- processed_event_count = 0
-
- for room_id, event_count in rooms_to_work_on:
-
- current_state_ids = yield self.get_current_state_ids(room_id)
-
- join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
- history_visibility_id = current_state_ids.get(
- (EventTypes.RoomHistoryVisibility, "")
- )
- encryption_id = current_state_ids.get((EventTypes.RoomEncryption, ""))
- name_id = current_state_ids.get((EventTypes.Name, ""))
- topic_id = current_state_ids.get((EventTypes.Topic, ""))
- avatar_id = current_state_ids.get((EventTypes.RoomAvatar, ""))
- canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, ""))
-
- state_events = yield self.get_events([
- join_rules_id, history_visibility_id, encryption_id, name_id,
- topic_id, avatar_id, canonical_alias_id,
- ])
-
- def _get_or_none(event_id, arg):
- event = state_events.get(event_id)
- if event:
- return event.content.get(arg)
- return None
-
- yield self.update_room_state(
- room_id,
- {
- "join_rules": _get_or_none(join_rules_id, "join_rule"),
- "history_visibility": _get_or_none(
- history_visibility_id, "history_visibility"
- ),
- "encryption": _get_or_none(encryption_id, "algorithm"),
- "name": _get_or_none(name_id, "name"),
- "topic": _get_or_none(topic_id, "topic"),
- "avatar": _get_or_none(avatar_id, "url"),
- "canonical_alias": _get_or_none(canonical_alias_id, "alias"),
- },
- )
-
- now = self.hs.get_reactor().seconds()
-
-            # quantise the timestamp down to the start of its stats bucket
- now = (now // self.stats_bucket_size) * self.stats_bucket_size
-
- def _fetch_data(txn):
-
- # Get the current token of the room
- current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
-
- current_state_events = len(current_state_ids)
-
- membership_counts = self._get_user_counts_in_room_txn(txn, room_id)
-
- total_state_events = self._get_total_state_event_counts_txn(
- txn, room_id
- )
-
- self._update_stats_txn(
- txn,
- "room",
- room_id,
- now,
- {
- "bucket_size": self.stats_bucket_size,
- "current_state_events": current_state_events,
- "joined_members": membership_counts.get(Membership.JOIN, 0),
- "invited_members": membership_counts.get(Membership.INVITE, 0),
- "left_members": membership_counts.get(Membership.LEAVE, 0),
- "banned_members": membership_counts.get(Membership.BAN, 0),
- "state_events": total_state_events,
- },
- )
- self._simple_insert_txn(
- txn,
- "room_stats_earliest_token",
- {"room_id": room_id, "token": current_token},
- )
-
- # We've finished a room. Delete it from the table.
- self._simple_delete_one_txn(
- txn, TEMP_TABLE + "_rooms", {"room_id": room_id},
- )
-
- yield self.runInteraction("update_room_stats", _fetch_data)
-
- # Update the remaining counter.
- progress["remaining"] -= 1
- yield self.runInteraction(
- "populate_stats",
- self._background_update_progress_txn,
- "populate_stats_process_rooms",
- progress,
- )
-
- processed_event_count += event_count
-
- if processed_event_count > batch_size:
- # Don't process any more rooms, we've hit our batch size.
- defer.returnValue(processed_event_count)
-
- defer.returnValue(processed_event_count)
-
- def delete_all_stats(self):
- """
- Delete all statistics records.
- """
-
- def _delete_all_stats_txn(txn):
- txn.execute("DELETE FROM room_state")
- txn.execute("DELETE FROM room_stats")
- txn.execute("DELETE FROM room_stats_earliest_token")
- txn.execute("DELETE FROM user_stats")
-
- return self.runInteraction("delete_all_stats", _delete_all_stats_txn)
-
- def get_stats_stream_pos(self):
- return self._simple_select_one_onecol(
- table="stats_stream_pos",
- keyvalues={},
- retcol="stream_id",
- desc="stats_stream_pos",
- )
-
- def update_stats_stream_pos(self, stream_id):
- return self._simple_update_one(
- table="stats_stream_pos",
- keyvalues={},
- updatevalues={"stream_id": stream_id},
- desc="update_stats_stream_pos",
- )
-
- def update_room_state(self, room_id, fields):
- """
- Args:
- room_id (str)
- fields (dict[str:Any])
- """
-
- # For whatever reason some of the fields may contain null bytes, which
- # postgres isn't a fan of, so we replace those fields with null.
- for col in (
- "join_rules",
- "history_visibility",
- "encryption",
- "name",
- "topic",
- "avatar",
- "canonical_alias"
- ):
- field = fields.get(col)
- if field and "\0" in field:
- fields[col] = None
-
- return self._simple_upsert(
- table="room_state",
- keyvalues={"room_id": room_id},
- values=fields,
- desc="update_room_state",
- )
-
- def get_deltas_for_room(self, room_id, start, size=100):
- """
- Get statistics deltas for a given room.
-
- Args:
- room_id (str)
- start (int): Pagination start. Number of entries, not timestamp.
- size (int): How many entries to return.
-
- Returns:
- Deferred[list[dict]], where the dict has the keys of
- ABSOLUTE_STATS_FIELDS["room"] and "ts".
- """
- return self._simple_select_list_paginate(
- "room_stats",
- {"room_id": room_id},
- "ts",
- start,
- size,
- retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]),
- order_direction="DESC",
- )
-
- def get_all_room_state(self):
- return self._simple_select_list(
- "room_state", None, retcols=("name", "topic", "canonical_alias")
- )
-
- @cached()
- def get_earliest_token_for_room_stats(self, room_id):
- """
- Fetch the "earliest token". This is used by the room stats delta
- processor to ignore deltas that have been processed between the
- start of the background task and any particular room's stats
- being calculated.
-
- Returns:
- Deferred[int]
- """
- return self._simple_select_one_onecol(
- "room_stats_earliest_token",
- {"room_id": room_id},
- retcol="token",
- allow_none=True,
- )
-
- def update_stats(self, stats_type, stats_id, ts, fields):
- table, id_col = TYPE_TO_ROOM[stats_type]
- return self._simple_upsert(
- table=table,
- keyvalues={id_col: stats_id, "ts": ts},
- values=fields,
- desc="update_stats",
- )
-
- def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields):
- table, id_col = TYPE_TO_ROOM[stats_type]
- return self._simple_upsert_txn(
- txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields
- )
-
- def update_stats_delta(self, ts, stats_type, stats_id, field, value):
- def _update_stats_delta(txn):
- table, id_col = TYPE_TO_ROOM[stats_type]
-
- sql = (
- "SELECT * FROM %s"
- " WHERE %s=? and ts=("
- " SELECT MAX(ts) FROM %s"
- " WHERE %s=?"
- ")"
- ) % (table, id_col, table, id_col)
- txn.execute(sql, (stats_id, stats_id))
- rows = self.cursor_to_dict(txn)
- if len(rows) == 0:
- # silently skip as we don't have anything to apply a delta to yet.
- # this tries to minimise any race between the initial sync and
- # subsequent deltas arriving.
- return
-
- current_ts = ts
- latest_ts = rows[0]["ts"]
- if current_ts < latest_ts:
- # This one is in the past, but we're just encountering it now.
- # Mark it as part of the current bucket.
- current_ts = latest_ts
- elif ts != latest_ts:
- # we have to copy our absolute counters over to the new entry.
- values = {
- key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type]
- }
- values[id_col] = stats_id
- values["ts"] = ts
- values["bucket_size"] = self.stats_bucket_size
-
- self._simple_insert_txn(txn, table=table, values=values)
-
- # actually update the new value
-            if field in ABSOLUTE_STATS_FIELDS[stats_type]:
- self._simple_update_txn(
- txn,
- table=table,
- keyvalues={id_col: stats_id, "ts": current_ts},
- updatevalues={field: value},
- )
- else:
- sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % (
- table,
- field,
- field,
- id_col,
- )
- txn.execute(sql, (value, stats_id, current_ts))
-
- return self.runInteraction("update_stats_delta", _update_stats_delta)
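The bookkeeping in update_stats_delta above reduces to: keep one row of absolute counters per time bucket, start a new bucket by copying the absolute values forward, fold late deltas into the newest bucket, and then either overwrite (absolute fields) or increment (everything else). A minimal in-memory sketch of that logic, ignoring the database entirely (names such as BUCKET_SIZE, ABSOLUTE_FIELDS and update_delta are illustrative, not Synapse's):

BUCKET_SIZE = 86_400_000  # e.g. one-day buckets, in milliseconds
ABSOLUTE_FIELDS = {"joined_members", "current_state_events"}

buckets = {}  # bucket start ts -> {field: value}

def update_delta(ts, field, value):
    ts = (ts // BUCKET_SIZE) * BUCKET_SIZE  # quantise to the bucket start
    if not buckets:
        return  # nothing to apply a delta to yet, as in the real store
    latest = max(buckets)
    if ts < latest:
        ts = latest  # a late delta is folded into the newest bucket
    elif ts > latest:
        # a new bucket starts by carrying the absolute counters forward
        buckets[ts] = {f: buckets[latest].get(f, 0) for f in ABSOLUTE_FIELDS}
    row = buckets[ts]
    if field in ABSOLUTE_FIELDS:
        row[field] = value  # absolute fields are overwritten
    else:
        row[field] = row.get(field, 0) + value  # other fields accumulate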
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
new file mode 100644
index 0000000000..daff81c5ee
--- /dev/null
+++ b/synapse/storage/types.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Iterable, Iterator, List, Tuple
+
+from typing_extensions import Protocol
+
+
+"""
+Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
+"""
+
+
+class Cursor(Protocol):
+ def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+ ...
+
+ def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+ ...
+
+ def fetchall(self) -> List[Tuple]:
+ ...
+
+ def fetchone(self) -> Tuple:
+ ...
+
+ @property
+ def description(self) -> Any:
+ return None
+
+ @property
+ def rowcount(self) -> int:
+ return 0
+
+ def __iter__(self) -> Iterator[Tuple]:
+ ...
+
+ def close(self) -> None:
+ ...
+
+
+class Connection(Protocol):
+ def cursor(self) -> Cursor:
+ ...
+
+ def close(self) -> None:
+ ...
+
+ def commit(self) -> None:
+ ...
+
+ def rollback(self, *args, **kwargs) -> None:
+ ...
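Because these are structural Protocols, any DB-API2 driver object with the right methods satisfies them; nothing has to subclass Cursor or Connection. A brief sketch of annotating a helper against them (the helper is illustrative, not part of the patch, and assumes this new module is importable):

import sqlite3

from synapse.storage.types import Connection, Cursor

def count_rows(conn: Connection, table: str) -> int:
    # `table` is interpolated directly, so it must come from trusted code
    cur: Cursor = conn.cursor()
    cur.execute("SELECT COUNT(*) FROM " + table)
    (count,) = cur.fetchone()
    cur.close()
    return count

# sqlite3 connections fit the Connection protocol structurally -- no
# inheritance needed, and sqlite3 knows nothing about Synapse.
db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE users (id INTEGER PRIMARY KEY)")
print(count_rows(db, "users"))  # -> 0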
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f1c8d99419..9d851beaa5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1):
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
- val, = cur.fetchone()
+ (val,) = cur.fetchone()
cur.close()
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
@@ -195,6 +195,6 @@ class ChainedIdGenerator(object):
with self._lock:
if self._unfinished_ids:
stream_id, chained_id = self._unfinished_ids[0]
- return (stream_id - 1, chained_id)
+ return stream_id - 1, chained_id
- return (self._current_max, self.chained_generator.get_current_token())
+ return self._current_max, self.chained_generator.get_current_token()
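The get_current_token change above preserves the generator's key invariant: while any allocated id is still unfinished, the advertised token stays strictly below it, so readers never see a stream position whose row might not be committed yet. A toy, in-memory illustration of that rule (not Synapse's actual class):

from collections import deque

class ToyStreamIdGen:
    def __init__(self, current_max=0):
        self._current_max = current_max
        self._unfinished = deque()  # ids handed out but not yet committed

    def allocate(self):
        self._current_max += 1
        self._unfinished.append(self._current_max)
        return self._current_max

    def mark_finished(self, stream_id):
        self._unfinished.remove(stream_id)

    def get_current_token(self):
        if self._unfinished:
            # stay one below the earliest in-flight id
            return self._unfinished[0] - 1
        return self._current_max

gen = ToyStreamIdGen()
first, second = gen.allocate(), gen.allocate()  # ids 1 and 2 in flight
assert gen.get_current_token() == 0             # nothing safe to read yet
gen.mark_finished(first)
assert gen.get_current_token() == 1             # id 1 done; id 2 still pending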