diff options
Diffstat (limited to 'synapse/storage')
77 files changed, 8585 insertions, 2618 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index d604e7668f..8cdfd50f90 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,13 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer - from synapse.storage.devices import DeviceStore from .appservice import ( ApplicationServiceStore, ApplicationServiceTransactionStore ) -from ._base import LoggingTransaction from .directory import DirectoryStore from .events import EventsStore from .presence import PresenceStore, UserPresenceState @@ -37,7 +35,7 @@ from .media_repository import MediaRepositoryStore from .rejections import RejectionsStore from .event_push_actions import EventPushActionsStore from .deviceinbox import DeviceInboxStore - +from .group_server import GroupServerStore from .state import StateStore from .signatures import SignatureStore from .filtering import FilteringStore @@ -49,6 +47,7 @@ from .tags import TagsStore from .account_data import AccountDataStore from .openid import OpenIdStore from .client_ips import ClientIpStore +from .user_directory import UserDirectoryStore from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator from .engines import PostgresEngine @@ -86,6 +85,8 @@ class DataStore(RoomMemberStore, RoomStore, ClientIpStore, DeviceStore, DeviceInboxStore, + UserDirectoryStore, + GroupServerStore, ): def __init__(self, db_conn, hs): @@ -101,12 +102,6 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "events", "stream_ordering", step=-1, extra_tables=[("ex_outlier_stream", "event_stream_ordering")] ) - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" - ) - self._account_data_id_gen = StreamIdGenerator( - db_conn, "account_data_max_stream_id", "stream_id" - ) self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) @@ -121,7 +116,6 @@ class DataStore(RoomMemberStore, RoomStore, ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") - self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") @@ -133,6 +127,9 @@ class DataStore(RoomMemberStore, RoomStore, db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")], ) + self._group_updates_id_gen = StreamIdGenerator( + db_conn, "local_group_updates", "stream_id", + ) if isinstance(self.database_engine, PostgresEngine): self._cache_id_gen = StreamIdGenerator( @@ -141,27 +138,6 @@ class DataStore(RoomMemberStore, RoomStore, else: self._cache_id_gen = None - events_max = self._stream_id_gen.get_current_token() - event_cache_prefill, min_event_val = self._get_cache_dict( - db_conn, "events", - entity_column="room_id", - stream_column="stream_ordering", - max_value=events_max, - ) - self._events_stream_cache = StreamChangeCache( - "EventsRoomStreamChangeCache", min_event_val, - prefilled_cache=event_cache_prefill, - ) - - self._membership_stream_cache = StreamChangeCache( - "MembershipStreamChangeCache", events_max, - ) - - account_max = self._account_data_id_gen.get_current_token() - self._account_data_stream_cache = StreamChangeCache( - "AccountDataAndTagsChangeCache", account_max, - ) - self._presence_on_startup = self._get_active_presence(db_conn) presence_cache_prefill, min_presence_val = self._get_cache_dict( @@ -175,18 +151,6 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=presence_cache_prefill ) - push_rules_prefill, push_rules_id = self._get_cache_dict( - db_conn, "push_rules_stream", - entity_column="user_id", - stream_column="stream_id", - max_value=self._push_rules_stream_id_gen.get_current_token()[0], - ) - - self.push_rules_stream_cache = StreamChangeCache( - "PushRulesStreamChangeCache", push_rules_id, - prefilled_cache=push_rules_prefill, - ) - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() device_inbox_prefill, min_device_inbox_id = self._get_cache_dict( db_conn, "device_inbox", @@ -221,23 +185,35 @@ class DataStore(RoomMemberStore, RoomStore, "DeviceListFederationStreamChangeCache", device_list_max, ) - cur = LoggingTransaction( - db_conn.cursor(), - name="_find_stream_orderings_for_times_txn", - database_engine=self.database_engine, - after_callbacks=[] + events_max = self._stream_id_gen.get_current_token() + curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict( + db_conn, "current_state_delta_stream", + entity_column="room_id", + stream_column="stream_id", + max_value=events_max, # As we share the stream id with events token + limit=1000, + ) + self._curr_state_delta_stream_cache = StreamChangeCache( + "_curr_state_delta_stream_cache", min_curr_state_delta_id, + prefilled_cache=curr_state_delta_prefill, ) - self._find_stream_orderings_for_times_txn(cur) - cur.close() - self.find_stream_orderings_looping_call = self._clock.looping_call( - self._find_stream_orderings_for_times, 60 * 60 * 1000 + _group_updates_prefill, min_group_updates_id = self._get_cache_dict( + db_conn, "local_group_updates", + entity_column="user_id", + stream_column="stream_id", + max_value=self._group_updates_id_gen.get_current_token(), + limit=1000, + ) + self._group_updates_stream_cache = StreamChangeCache( + "_group_updates_stream_cache", min_group_updates_id, + prefilled_cache=_group_updates_prefill, ) self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() - super(DataStore, self).__init__(hs) + super(DataStore, self).__init__(db_conn, hs) def take_presence_startup_info(self): active_on_startup = self._presence_on_startup @@ -266,36 +242,110 @@ class DataStore(RoomMemberStore, RoomStore, return [UserPresenceState(**row) for row in rows] - @defer.inlineCallbacks def count_daily_users(self): """ Counts the number of users who used this homeserver in the last 24 hours. """ def _count_users(txn): - txn.execute( - "SELECT COUNT(DISTINCT user_id) AS users" - " FROM user_ips" - " WHERE last_seen > ?", - # This is close enough to a day for our purposes. - (int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),) - ) - rows = self.cursor_to_dict(txn) - if rows: - return rows[0]["users"] - return 0 + yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - ret = yield self.runInteraction("count_users", _count_users) - defer.returnValue(ret) + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT user_id FROM user_ips + WHERE last_seen > ? + GROUP BY user_id + ) u + """ - def get_user_ip_and_agents(self, user): - return self._simple_select_list( - table="user_ips", - keyvalues={"user_id": user.to_string()}, - retcols=[ - "access_token", "ip", "user_agent", "last_seen" - ], - desc="get_user_ip_and_agents", - ) + txn.execute(sql, (yesterday,)) + count, = txn.fetchone() + return count + + return self.runInteraction("count_users", _count_users) + + def count_r30_users(self): + """ + Counts the number of 30 day retained users, defined as:- + * Users who have created their accounts more than 30 days ago + * Where last seen at most 30 days ago + * Where account creation and last_seen are > 30 days apart + + Returns counts globaly for a given user as well as breaking + by platform + """ + def _count_r30_users(txn): + thirty_days_in_secs = 86400 * 30 + now = int(self._clock.time()) + thirty_days_ago_in_secs = now - thirty_days_in_secs + + sql = """ + SELECT platform, COALESCE(count(*), 0) FROM ( + SELECT + users.name, platform, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen, + CASE + WHEN user_agent LIKE '%%Android%%' THEN 'android' + WHEN user_agent LIKE '%%iOS%%' THEN 'ios' + WHEN user_agent LIKE '%%Electron%%' THEN 'electron' + WHEN user_agent LIKE '%%Mozilla%%' THEN 'web' + WHEN user_agent LIKE '%%Gecko%%' THEN 'web' + ELSE 'unknown' + END + AS platform + FROM user_ips + ) uip + ON users.name = uip.user_id + AND users.appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, platform, users.creation_ts + ) u GROUP BY platform + """ + + results = {} + txn.execute(sql, (thirty_days_ago_in_secs, + thirty_days_ago_in_secs)) + + for row in txn: + if row[0] is 'unknown': + pass + results[row[0]] = row[1] + + sql = """ + SELECT COALESCE(count(*), 0) FROM ( + SELECT users.name, users.creation_ts * 1000, + MAX(uip.last_seen) + FROM users + INNER JOIN ( + SELECT + user_id, + last_seen + FROM user_ips + ) uip + ON users.name = uip.user_id + AND appservice_id is NULL + AND users.creation_ts < ? + AND uip.last_seen/1000 > ? + AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30 + GROUP BY users.name, users.creation_ts + ) u + """ + + txn.execute(sql, (thirty_days_ago_in_secs, + thirty_days_ago_in_secs)) + + count, = txn.fetchone() + results['all'] = count + + return results + + return self.runInteraction("count_r30_users", _count_r30_users) def get_users(self): """Function to reterive a list of users in users table. diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index b0dc391190..2262776ab2 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -16,9 +16,7 @@ import logging from synapse.api.errors import StoreError from synapse.util.logcontext import LoggingContext, PreserveLoggingContext -from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.descriptors import Cache -from synapse.util.caches import intern_dict from synapse.storage.engines import PostgresEngine import synapse.metrics @@ -28,10 +26,6 @@ from twisted.internet import defer import sys import time import threading -import os - - -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1)) logger = logging.getLogger(__name__) @@ -53,20 +47,27 @@ class LoggingTransaction(object): """An object that almost-transparently proxies for the 'txn' object passed to the constructor. Adds logging and metrics to the .execute() method.""" - __slots__ = ["txn", "name", "database_engine", "after_callbacks"] + __slots__ = [ + "txn", "name", "database_engine", "after_callbacks", "exception_callbacks", + ] - def __init__(self, txn, name, database_engine, after_callbacks): + def __init__(self, txn, name, database_engine, after_callbacks, + exception_callbacks): object.__setattr__(self, "txn", txn) object.__setattr__(self, "name", name) object.__setattr__(self, "database_engine", database_engine) object.__setattr__(self, "after_callbacks", after_callbacks) + object.__setattr__(self, "exception_callbacks", exception_callbacks) - def call_after(self, callback, *args): + def call_after(self, callback, *args, **kwargs): """Call the given callback on the main twisted thread after the transaction has finished. Used to invalidate the caches on the correct thread. """ - self.after_callbacks.append((callback, args)) + self.after_callbacks.append((callback, args, kwargs)) + + def call_on_exception(self, callback, *args, **kwargs): + self.exception_callbacks.append((callback, args, kwargs)) def __getattr__(self, name): return getattr(self.txn, name) @@ -74,13 +75,22 @@ class LoggingTransaction(object): def __setattr__(self, name, value): setattr(self.txn, name, value) + def __iter__(self): + return self.txn.__iter__() + def execute(self, sql, *args): self._do_execute(self.txn.execute, sql, *args) def executemany(self, sql, *args): self._do_execute(self.txn.executemany, sql, *args) + def _make_sql_one_line(self, sql): + "Strip newlines out of SQL so that the loggers in the DB are on one line" + return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + def _do_execute(self, func, sql, *args): + sql = self._make_sql_one_line(sql) + # TODO(paul): Maybe use 'info' and 'debug' for values? sql_logger.debug("[SQL] {%s} %s", self.name, sql) @@ -91,7 +101,7 @@ class LoggingTransaction(object): "[SQL values] {%s} %r", self.name, args[0] ) - except: + except Exception: # Don't let logging failures stop SQL from working pass @@ -127,7 +137,7 @@ class PerformanceCounters(object): def interval(self, interval_duration, limit=3): counters = [] - for name, (count, cum_time) in self.current_counters.items(): + for name, (count, cum_time) in self.current_counters.iteritems(): prev_count, prev_time = self.previous_counters.get(name, (0, 0)) counters.append(( (cum_time - prev_time) / interval_duration, @@ -150,7 +160,7 @@ class PerformanceCounters(object): class SQLBaseStore(object): _TXN_ID = 0 - def __init__(self, hs): + def __init__(self, db_conn, hs): self.hs = hs self._clock = hs.get_clock() self._db_pool = hs.get_db_pool() @@ -168,10 +178,6 @@ class SQLBaseStore(object): self._get_event_cache = Cache("*getEvent*", keylen=3, max_entries=hs.config.event_cache_size) - self._state_group_cache = DictionaryCache( - "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR - ) - self._event_fetch_lock = threading.Condition() self._event_fetch_list = [] self._event_fetch_ongoing = 0 @@ -209,8 +215,8 @@ class SQLBaseStore(object): self._clock.looping_call(loop, 10000) - def _new_transaction(self, conn, desc, after_callbacks, logging_context, - func, *args, **kwargs): + def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks, + logging_context, func, *args, **kwargs): start = time.time() * 1000 txn_id = self._TXN_ID @@ -229,7 +235,8 @@ class SQLBaseStore(object): try: txn = conn.cursor() txn = LoggingTransaction( - txn, name, self.database_engine, after_callbacks + txn, name, self.database_engine, after_callbacks, + exception_callbacks, ) r = func(txn, *args, **kwargs) conn.commit() @@ -284,47 +291,66 @@ class SQLBaseStore(object): @defer.inlineCallbacks def runInteraction(self, desc, func, *args, **kwargs): - """Wraps the .runInteraction() method on the underlying db_pool.""" - current_context = LoggingContext.current_context() + """Starts a transaction on the database and runs a given function - start_time = time.time() * 1000 + Arguments: + desc (str): description of the transaction, for logging and metrics + func (func): callback function, which will be called with a + database transaction (twisted.enterprise.adbapi.Transaction) as + its first argument, followed by `args` and `kwargs`. + + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ + current_context = LoggingContext.current_context() after_callbacks = [] + exception_callbacks = [] def inner_func(conn, *args, **kwargs): - with LoggingContext("runInteraction") as context: - sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + return self._new_transaction( + conn, desc, after_callbacks, exception_callbacks, current_context, + func, *args, **kwargs + ) - if self.database_engine.is_connection_closed(conn): - logger.debug("Reconnecting closed database connection") - conn.reconnect() + try: + result = yield self.runWithConnection(inner_func, *args, **kwargs) - current_context.copy_to(context) - return self._new_transaction( - conn, desc, after_callbacks, current_context, - func, *args, **kwargs - ) + for after_callback, after_args, after_kwargs in after_callbacks: + after_callback(*after_args, **after_kwargs) + except: # noqa: E722, as we reraise the exception this is fine. + for after_callback, after_args, after_kwargs in exception_callbacks: + after_callback(*after_args, **after_kwargs) + raise - try: - with PreserveLoggingContext(): - result = yield self._db_pool.runWithConnection( - inner_func, *args, **kwargs - ) - finally: - for after_callback, after_args in after_callbacks: - after_callback(*after_args) defer.returnValue(result) @defer.inlineCallbacks def runWithConnection(self, func, *args, **kwargs): - """Wraps the .runInteraction() method on the underlying db_pool.""" + """Wraps the .runWithConnection() method on the underlying db_pool. + + Arguments: + func (func): callback function, which will be called with a + database connection (twisted.enterprise.adbapi.Connection) as + its first argument, followed by `args` and `kwargs`. + args (list): positional args to pass to `func` + kwargs (dict): named args to pass to `func` + + Returns: + Deferred: The result of func + """ current_context = LoggingContext.current_context() start_time = time.time() * 1000 def inner_func(conn, *args, **kwargs): with LoggingContext("runWithConnection") as context: - sql_scheduling_timer.inc_by(time.time() * 1000 - start_time) + sched_duration_ms = time.time() * 1000 - start_time + sql_scheduling_timer.inc_by(sched_duration_ms) + current_context.add_database_scheduled(sched_duration_ms) if self.database_engine.is_connection_closed(conn): logger.debug("Reconnecting closed database connection") @@ -350,9 +376,9 @@ class SQLBaseStore(object): Returns: A list of dicts where the key is the column header. """ - col_headers = list(column[0] for column in cursor.description) + col_headers = list(intern(str(column[0])) for column in cursor.description) results = list( - intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall() + dict(zip(col_headers, row)) for row in cursor ) return results @@ -417,6 +443,11 @@ class SQLBaseStore(object): txn.execute(sql, vals) + def _simple_insert_many(self, table, values, desc): + return self.runInteraction( + desc, self._simple_insert_many_txn, table, values + ) + @staticmethod def _simple_insert_many_txn(txn, table, values): if not values: @@ -452,23 +483,53 @@ class SQLBaseStore(object): txn.executemany(sql, vals) + @defer.inlineCallbacks def _simple_upsert(self, table, keyvalues, values, insertion_values={}, desc="_simple_upsert", lock=True): """ + + `lock` should generally be set to True (the default), but can be set + to False if either of the following are true: + + * there is a UNIQUE INDEX on the key columns. In this case a conflict + will cause an IntegrityError in which case this function will retry + the update. + + * we somehow know that we are the only thread which will be updating + this table. + Args: table (str): The table to upsert into keyvalues (dict): The unique key tables and their new values values (dict): The nonunique columns and their new values - insertion_values (dict): key/values to use when inserting + insertion_values (dict): additional key/values to use only when + inserting + lock (bool): True to lock the table when doing the upsert. Returns: Deferred(bool): True if a new entry was created, False if an existing one was updated. """ - return self.runInteraction( - desc, - self._simple_upsert_txn, table, keyvalues, values, insertion_values, - lock - ) + attempts = 0 + while True: + try: + result = yield self.runInteraction( + desc, + self._simple_upsert_txn, table, keyvalues, values, insertion_values, + lock=lock + ) + defer.returnValue(result) + except self.database_engine.module.IntegrityError as e: + attempts += 1 + if attempts >= 5: + # don't retry forever, because things other than races + # can cause IntegrityErrors + raise + + # presumably we raced with another transaction: let's retry. + logger.warn( + "IntegrityError when upserting into %s; retrying: %s", + table, e + ) def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={}, lock=True): @@ -476,45 +537,38 @@ class SQLBaseStore(object): if lock: self.database_engine.lock_table(txn, table) - # Try to update + # First try to update. sql = "UPDATE %s SET %s WHERE %s" % ( table, ", ".join("%s = ?" % (k,) for k in values), " AND ".join("%s = ?" % (k,) for k in keyvalues) ) sqlargs = values.values() + keyvalues.values() - logger.debug( - "[SQL] %s Args=%s", - sql, sqlargs, - ) txn.execute(sql, sqlargs) - if txn.rowcount == 0: - # We didn't update and rows so insert a new one - allvalues = {} - allvalues.update(keyvalues) - allvalues.update(values) - allvalues.update(insertion_values) + if txn.rowcount > 0: + # successfully updated at least one row. + return False - sql = "INSERT INTO %s (%s) VALUES (%s)" % ( - table, - ", ".join(k for k in allvalues), - ", ".join("?" for _ in allvalues) - ) - logger.debug( - "[SQL] %s Args=%s", - sql, keyvalues.values(), - ) - txn.execute(sql, allvalues.values()) + # We didn't update any rows so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + allvalues.update(insertion_values) - return True - else: - return False + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues) + ) + txn.execute(sql, allvalues.values()) + # successfully inserted + return True def _simple_select_one(self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"): """Executes a SELECT query on the named table, which is expected to - return a single row, returning a single column from it. + return a single row, returning multiple columns from it. Args: table : string giving the table name @@ -567,22 +621,20 @@ class SQLBaseStore(object): @staticmethod def _simple_select_onecol_txn(txn, table, keyvalues, retcol): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) - else: - where = "" - sql = ( - "SELECT %(retcol)s FROM %(table)s %(where)s" + "SELECT %(retcol)s FROM %(table)s" ) % { "retcol": retcol, "table": table, - "where": where, } - txn.execute(sql, keyvalues.values()) + if keyvalues: + sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) + txn.execute(sql, keyvalues.values()) + else: + txn.execute(sql) - return [r[0] for r in txn.fetchall()] + return [r[0] for r in txn] def _simple_select_onecol(self, table, keyvalues, retcol, desc="_simple_select_onecol"): @@ -591,7 +643,7 @@ class SQLBaseStore(object): Args: table (str): table name - keyvalues (dict): column names and values to select the rows with + keyvalues (dict|None): column names and values to select the rows with retcol (str): column whos value we wish to retrieve. Returns: @@ -715,7 +767,7 @@ class SQLBaseStore(object): ) values.extend(iterable) - for key, value in keyvalues.items(): + for key, value in keyvalues.iteritems(): clauses.append("%s = ?" % (key,)) values.append(value) @@ -728,6 +780,33 @@ class SQLBaseStore(object): txn.execute(sql, values) return cls.cursor_to_dict(txn) + def _simple_update(self, table, keyvalues, updatevalues, desc): + return self.runInteraction( + desc, + self._simple_update_txn, + table, keyvalues, updatevalues, + ) + + @staticmethod + def _simple_update_txn(txn, table, keyvalues, updatevalues): + if keyvalues: + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) + else: + where = "" + + update_sql = "UPDATE %s SET %s %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in updatevalues), + where, + ) + + txn.execute( + update_sql, + updatevalues.values() + keyvalues.values() + ) + + return txn.rowcount + def _simple_update_one(self, table, keyvalues, updatevalues, desc="_simple_update_one"): """Executes an UPDATE query on the named table, setting new values for @@ -753,27 +832,13 @@ class SQLBaseStore(object): table, keyvalues, updatevalues, ) - @staticmethod - def _simple_update_one_txn(txn, table, keyvalues, updatevalues): - if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) - else: - where = "" - - update_sql = "UPDATE %s SET %s %s" % ( - table, - ", ".join("%s = ?" % (k,) for k in updatevalues), - where, - ) - - txn.execute( - update_sql, - updatevalues.values() + keyvalues.values() - ) + @classmethod + def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues): + rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues) - if txn.rowcount == 0: + if rowcount == 0: raise StoreError(404, "No row found") - if txn.rowcount > 1: + if rowcount > 1: raise StoreError(500, "More than one row matched") @staticmethod @@ -843,6 +908,47 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) + def _simple_delete_many(self, table, column, iterable, keyvalues, desc): + return self.runInteraction( + desc, self._simple_delete_many_txn, table, column, iterable, keyvalues + ) + + @staticmethod + def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + """ + if not iterable: + return + + sql = "DELETE FROM %s" % table + + clauses = [] + values = [] + clauses.append( + "%s IN (%s)" % (column, ",".join("?" for _ in iterable)) + ) + values.extend(iterable) + + for key, value in keyvalues.iteritems(): + clauses.append("%s = ?" % (key,)) + values.append(value) + + if clauses: + sql = "%s WHERE %s" % ( + sql, + " AND ".join(clauses), + ) + return txn.execute(sql, values) + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value, limit=100000): # Fetch a mapping of room_id -> max stream position for "recent" rooms. @@ -863,16 +969,16 @@ class SQLBaseStore(object): txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) - rows = txn.fetchall() - txn.close() cache = { row[0]: int(row[1]) - for row in rows + for row in txn } + txn.close() + if cache: - min_val = min(cache.values()) + min_val = min(cache.itervalues()) else: min_val = max_value @@ -895,6 +1001,7 @@ class SQLBaseStore(object): # __exit__ called after the transaction finishes. ctx = self._cache_id_gen.get_next() stream_id = ctx.__enter__() + txn.call_on_exception(ctx.__exit__, None, None, None) txn.call_after(ctx.__exit__, None, None, None) txn.call_after(self.hs.get_notifier().on_new_replication_data) diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 3fa226e92d..f83ff0454a 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,18 +14,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore from twisted.internet import defer +from synapse.storage._base import SQLBaseStore +from synapse.storage.util.id_generators import StreamIdGenerator + +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks -import ujson as json +import abc +import simplejson as json import logging logger = logging.getLogger(__name__) -class AccountDataStore(SQLBaseStore): +class AccountDataWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_account_data_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + account_max = self.get_max_account_data_stream_id() + self._account_data_stream_cache = StreamChangeCache( + "AccountDataAndTagsChangeCache", account_max, + ) + + super(AccountDataWorkerStore, self).__init__(db_conn, hs) + + @abc.abstractmethod + def get_max_account_data_stream_id(self): + """Get the current max stream ID for account data stream + + Returns: + int + """ + raise NotImplementedError() @cached() def get_account_data_for_user(self, user_id): @@ -63,7 +92,7 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_user", get_account_data_for_user_txn ) - @cachedInlineCallbacks(num_args=2) + @cachedInlineCallbacks(num_args=2, max_entries=5000) def get_global_account_data_by_type_for_user(self, data_type, user_id): """ Returns: @@ -104,6 +133,7 @@ class AccountDataStore(SQLBaseStore): for row in rows }) + @cached(num_args=2) def get_account_data_for_room(self, user_id, room_id): """Get all the client account_data for a user for a room. @@ -127,6 +157,38 @@ class AccountDataStore(SQLBaseStore): "get_account_data_for_room", get_account_data_for_room_txn ) + @cached(num_args=3, max_entries=5000) + def get_account_data_for_room_and_type(self, user_id, room_id, account_data_type): + """Get the client account_data of given type for a user for a room. + + Args: + user_id(str): The user to get the account_data for. + room_id(str): The room to get the account_data for. + account_data_type (str): The account data type to get. + Returns: + A deferred of the room account_data for that type, or None if + there isn't any set. + """ + def get_account_data_for_room_and_type_txn(txn): + content_json = self._simple_select_one_onecol_txn( + txn, + table="room_account_data", + keyvalues={ + "user_id": user_id, + "room_id": room_id, + "account_data_type": account_data_type, + }, + retcol="content", + allow_none=True + ) + + return json.loads(content_json) if content_json else None + + return self.runInteraction( + "get_account_data_for_room_and_type", + get_account_data_for_room_and_type_txn, + ) + def get_all_updated_account_data(self, last_global_id, last_room_id, current_id, limit): """Get all the client account_data that has changed on the server @@ -182,7 +244,7 @@ class AccountDataStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) global_account_data = { - row[0]: json.loads(row[1]) for row in txn.fetchall() + row[0]: json.loads(row[1]) for row in txn } sql = ( @@ -193,7 +255,7 @@ class AccountDataStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) account_data_by_room = {} - for row in txn.fetchall(): + for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = json.loads(row[2]) @@ -209,6 +271,36 @@ class AccountDataStore(SQLBaseStore): "get_updated_account_data_for_user", get_updated_account_data_for_user_txn ) + @cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000) + def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context): + ignored_account_data = yield self.get_global_account_data_by_type_for_user( + "m.ignored_user_list", ignorer_user_id, + on_invalidate=cache_context.invalidate, + ) + if not ignored_account_data: + defer.returnValue(False) + + defer.returnValue( + ignored_user_id in ignored_account_data.get("ignored_users", {}) + ) + + +class AccountDataStore(AccountDataWorkerStore): + def __init__(self, db_conn, hs): + self._account_data_id_gen = StreamIdGenerator( + db_conn, "account_data_max_stream_id", "stream_id" + ) + + super(AccountDataStore, self).__init__(db_conn, hs) + + def get_max_account_data_stream_id(self): + """Get the current max stream id for the private user data stream + + Returns: + A deferred int. + """ + return self._account_data_id_gen.get_current_token() + @defer.inlineCallbacks def add_account_data_to_room(self, user_id, room_id, account_data_type, content): """Add some account_data to a room for a user. @@ -222,9 +314,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as room_account_data has a unique constraint + # on (user_id, room_id, account_data_type) so _simple_upsert will + # retry if there is a conflict. + yield self._simple_upsert( + desc="add_room_account_data", table="room_account_data", keyvalues={ "user_id": user_id, @@ -234,18 +329,23 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } - ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, - user_id, next_id, + }, + lock=False, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - self._update_max_stream_id(txn, next_id) - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_room_account_data", add_account_data_txn, next_id + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed(user_id, next_id) + self.get_account_data_for_user.invalidate((user_id,)) + self.get_account_data_for_room.invalidate((user_id, room_id,)) + self.get_account_data_for_room_and_type.prefill( + (user_id, room_id, account_data_type,), content, ) result = self._account_data_id_gen.get_current_token() @@ -263,9 +363,12 @@ class AccountDataStore(SQLBaseStore): """ content_json = json.dumps(content) - def add_account_data_txn(txn, next_id): - self._simple_upsert_txn( - txn, + with self._account_data_id_gen.get_next() as next_id: + # no need to lock here as account_data has a unique constraint on + # (user_id, account_data_type) so _simple_upsert will retry if + # there is a conflict. + yield self._simple_upsert( + desc="add_user_account_data", table="account_data", keyvalues={ "user_id": user_id, @@ -274,37 +377,43 @@ class AccountDataStore(SQLBaseStore): values={ "stream_id": next_id, "content": content_json, - } + }, + lock=False, ) - txn.call_after( - self._account_data_stream_cache.entity_has_changed, + + # it's theoretically possible for the above to succeed and the + # below to fail - in which case we might reuse a stream id on + # restart, and the above update might not get propagated. That + # doesn't sound any worse than the whole update getting lost, + # which is what would happen if we combined the two into one + # transaction. + yield self._update_max_stream_id(next_id) + + self._account_data_stream_cache.entity_has_changed( user_id, next_id, ) - txn.call_after(self.get_account_data_for_user.invalidate, (user_id,)) - txn.call_after( - self.get_global_account_data_by_type_for_user.invalidate, + self.get_account_data_for_user.invalidate((user_id,)) + self.get_global_account_data_by_type_for_user.invalidate( (account_data_type, user_id,) ) - self._update_max_stream_id(txn, next_id) - - with self._account_data_id_gen.get_next() as next_id: - yield self.runInteraction( - "add_user_account_data", add_account_data_txn, next_id - ) result = self._account_data_id_gen.get_current_token() defer.returnValue(result) - def _update_max_stream_id(self, txn, next_id): + def _update_max_stream_id(self, next_id): """Update the max stream_id Args: - txn: The database cursor next_id(int): The the revision to advance to. """ - update_max_id_sql = ( - "UPDATE account_data_max_stream_id" - " SET stream_id = ?" - " WHERE stream_id < ?" + def _update(txn): + update_max_id_sql = ( + "UPDATE account_data_max_stream_id" + " SET stream_id = ?" + " WHERE stream_id < ?" + ) + txn.execute(update_max_id_sql, (next_id, next_id)) + return self.runInteraction( + "update_account_data_max_stream_id", + _update, ) - txn.execute(update_max_id_sql, (next_id, next_id)) diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 514570561f..12ea8a158c 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,39 +14,58 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re import simplejson as json from twisted.internet import defer -from synapse.api.constants import Membership from synapse.appservice import AppServiceTransaction from synapse.config.appservice import load_appservices -from synapse.storage.roommember import RoomsForUser +from synapse.storage.events import EventsWorkerStore from ._base import SQLBaseStore logger = logging.getLogger(__name__) -class ApplicationServiceStore(SQLBaseStore): +def _make_exclusive_regex(services_cache): + # We precompie a regex constructed from all the regexes that the AS's + # have registered for exclusive users. + exclusive_user_regexes = [ + regex.pattern + for service in services_cache + for regex in service.get_exlusive_user_regexes() + ] + if exclusive_user_regexes: + exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes) + exclusive_user_regex = re.compile(exclusive_user_regex) + else: + # We handle this case specially otherwise the constructed regex + # will always match + exclusive_user_regex = None - def __init__(self, hs): - super(ApplicationServiceStore, self).__init__(hs) - self.hostname = hs.hostname + return exclusive_user_regex + + +class ApplicationServiceWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): self.services_cache = load_appservices( hs.hostname, hs.config.app_service_config_files ) + self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) + + super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs) def get_app_services(self): return self.services_cache def get_if_app_services_interested_in_user(self, user_id): - """Check if the user is one associated with an app service + """Check if the user is one associated with an app service (exclusively) """ - for service in self.services_cache: - if service.is_interested_in_user(user_id): - return True - return False + if self.exclusive_user_regex: + return bool(self.exclusive_user_regex.match(user_id)) + else: + return False def get_app_service_by_user_id(self, user_id): """Retrieve an application service from their user ID. @@ -78,83 +98,30 @@ class ApplicationServiceStore(SQLBaseStore): return service return None - def get_app_service_rooms(self, service): - """Get a list of RoomsForUser for this application service. - - Application services may be "interested" in lots of rooms depending on - the room ID, the room aliases, or the members in the room. This function - takes all of these into account and returns a list of RoomsForUser which - represent the entire list of room IDs that this application service - wants to know about. + def get_app_service_by_id(self, as_id): + """Get the application service with the given appservice ID. Args: - service: The application service to get a room list for. + as_id (str): The application service ID. Returns: - A list of RoomsForUser. + synapse.appservice.ApplicationService or None. """ - return self.runInteraction( - "get_app_service_rooms", - self._get_app_service_rooms_txn, - service, - ) - - def _get_app_service_rooms_txn(self, txn, service): - # get all rooms matching the room ID regex. - room_entries = self._simple_select_list_txn( - txn=txn, table="rooms", keyvalues=None, retcols=["room_id"] - ) - matching_room_list = set([ - r["room_id"] for r in room_entries if - service.is_interested_in_room(r["room_id"]) - ]) - - # resolve room IDs for matching room alias regex. - room_alias_mappings = self._simple_select_list_txn( - txn=txn, table="room_aliases", keyvalues=None, - retcols=["room_id", "room_alias"] - ) - matching_room_list |= set([ - r["room_id"] for r in room_alias_mappings if - service.is_interested_in_alias(r["room_alias"]) - ]) - - # get all rooms for every user for this AS. This is scoped to users on - # this HS only. - user_list = self._simple_select_list_txn( - txn=txn, table="users", keyvalues=None, retcols=["name"] - ) - user_list = [ - u["name"] for u in user_list if - service.is_interested_in_user(u["name"]) - ] - rooms_for_user_matching_user_id = set() # RoomsForUser list - for user_id in user_list: - # FIXME: This assumes this store is linked with RoomMemberStore :( - rooms_for_user = self._get_rooms_for_user_where_membership_is_txn( - txn=txn, - user_id=user_id, - membership_list=[Membership.JOIN] - ) - rooms_for_user_matching_user_id |= set(rooms_for_user) - - # make RoomsForUser tuples for room ids and aliases which are not in the - # main rooms_for_user_list - e.g. they are rooms which do not have AS - # registered users in it. - known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id] - missing_rooms_for_user = [ - RoomsForUser(r, service.sender, "join") for r in - matching_room_list if r not in known_room_ids - ] - rooms_for_user_matching_user_id |= set(missing_rooms_for_user) - - return rooms_for_user_matching_user_id + for service in self.services_cache: + if service.id == as_id: + return service + return None -class ApplicationServiceTransactionStore(SQLBaseStore): +class ApplicationServiceStore(ApplicationServiceWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass - def __init__(self, hs): - super(ApplicationServiceTransactionStore, self).__init__(hs) +class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore, + EventsWorkerStore): @defer.inlineCallbacks def get_appservices_by_state(self, state): """Get a list of application services based on their state. @@ -399,3 +366,11 @@ class ApplicationServiceTransactionStore(SQLBaseStore): events = yield self._get_events(event_ids) defer.returnValue((upper_bound, events)) + + +class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore): + # This is currently empty due to there not being any AS storage functions + # that can't be run on the workers. Since this may change in future, and + # to keep consistency with the other stores, we keep this empty class for + # now. + pass diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 94b2bcc54a..8af325a9f5 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -12,13 +12,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import synapse.util.async from ._base import SQLBaseStore from . import engines from twisted.internet import defer -import ujson as json +import simplejson as json import logging logger = logging.getLogger(__name__) @@ -79,35 +80,26 @@ class BackgroundUpdateStore(SQLBaseStore): BACKGROUND_UPDATE_INTERVAL_MS = 1000 BACKGROUND_UPDATE_DURATION_MS = 100 - def __init__(self, hs): - super(BackgroundUpdateStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(BackgroundUpdateStore, self).__init__(db_conn, hs) self._background_update_performance = {} self._background_update_queue = [] self._background_update_handlers = {} - self._background_update_timer = None + self._all_done = False @defer.inlineCallbacks def start_doing_background_updates(self): - assert self._background_update_timer is None, \ - "background updates already running" - logger.info("Starting background schema updates") while True: - sleep = defer.Deferred() - self._background_update_timer = self._clock.call_later( - self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None - ) - try: - yield sleep - finally: - self._background_update_timer = None + yield synapse.util.async.sleep( + self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.) try: result = yield self.do_next_background_update( self.BACKGROUND_UPDATE_DURATION_MS ) - except: + except Exception: logger.exception("Error doing update") else: if result is None: @@ -115,9 +107,41 @@ class BackgroundUpdateStore(SQLBaseStore): "No more background updates to do." " Unscheduling background update task." ) + self._all_done = True defer.returnValue(None) @defer.inlineCallbacks + def has_completed_background_updates(self): + """Check if all the background updates have completed + + Returns: + Deferred[bool]: True if all background updates have completed + """ + # if we've previously determined that there is nothing left to do, that + # is easy + if self._all_done: + defer.returnValue(True) + + # obviously, if we have things in our queue, we're not done. + if self._background_update_queue: + defer.returnValue(False) + + # otherwise, check if there are updates to be run. This is important, + # as we may be running on a worker which doesn't perform the bg updates + # itself, but still wants to wait for them to happen. + updates = yield self._simple_select_onecol( + "background_updates", + keyvalues=None, + retcol="1", + desc="check_background_updates", + ) + if not updates: + self._all_done = True + defer.returnValue(True) + + defer.returnValue(False) + + @defer.inlineCallbacks def do_next_background_update(self, desired_duration_ms): """Does some amount of work on the next queued background update @@ -218,8 +242,29 @@ class BackgroundUpdateStore(SQLBaseStore): """ self._background_update_handlers[update_name] = update_handler + def register_noop_background_update(self, update_name): + """Register a noop handler for a background update. + + This is useful when we previously did a background update, but no + longer wish to do the update. In this case the background update should + be removed from the schema delta files, but there may still be some + users who have the background update queued, so this method should + also be called to clear the update. + + Args: + update_name (str): Name of update + """ + @defer.inlineCallbacks + def noop_update(progress, batch_size): + yield self._end_background_update(update_name) + defer.returnValue(1) + + self.register_background_update_handler(update_name, noop_update) + def register_background_index_update(self, update_name, index_name, - table, columns, where_clause=None): + table, columns, where_clause=None, + unique=False, + psql_only=False): """Helper for store classes to do a background index addition To use: @@ -235,48 +280,80 @@ class BackgroundUpdateStore(SQLBaseStore): index_name (str): name of index to add table (str): table to add index to columns (list[str]): columns/expressions to include in index + unique (bool): true to make a UNIQUE index + psql_only: true to only create this index on psql databases (useful + for virtual sqlite tables) """ - # if this is postgres, we add the indexes concurrently. Otherwise - # we fall back to doing it inline - if isinstance(self.database_engine, engines.PostgresEngine): - conc = True - else: - conc = False - # We don't use partial indices on SQLite as it wasn't introduced - # until 3.8, and wheezy has 3.7 - where_clause = None - - sql = ( - "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" - " %(where_clause)s" - ) % { - "conc": "CONCURRENTLY" if conc else "", - "name": index_name, - "table": table, - "columns": ", ".join(columns), - "where_clause": "WHERE " + where_clause if where_clause else "" - } - - def create_index_concurrently(conn): + def create_index_psql(conn): conn.rollback() # postgres insists on autocommit for the index conn.set_session(autocommit=True) - c = conn.cursor() - c.execute(sql) - conn.set_session(autocommit=False) - def create_index(conn): + try: + c = conn.cursor() + + # If a previous attempt to create the index was interrupted, + # we may already have a half-built index. Let's just drop it + # before trying to create it again. + + sql = "DROP INDEX IF EXISTS %s" % (index_name,) + logger.debug("[SQL] %s", sql) + c.execute(sql) + + sql = ( + "CREATE %(unique)s INDEX CONCURRENTLY %(name)s" + " ON %(table)s" + " (%(columns)s) %(where_clause)s" + ) % { + "unique": "UNIQUE" if unique else "", + "name": index_name, + "table": table, + "columns": ", ".join(columns), + "where_clause": "WHERE " + where_clause if where_clause else "" + } + logger.debug("[SQL] %s", sql) + c.execute(sql) + finally: + conn.set_session(autocommit=False) + + def create_index_sqlite(conn): + # Sqlite doesn't support concurrent creation of indexes. + # + # We don't use partial indices on SQLite as it wasn't introduced + # until 3.8, and wheezy and CentOS 7 have 3.7 + # + # We assume that sqlite doesn't give us invalid indices; however + # we may still end up with the index existing but the + # background_updates not having been recorded if synapse got shut + # down at the wrong moment - hance we use IF NOT EXISTS. (SQLite + # has supported CREATE TABLE|INDEX IF NOT EXISTS since 3.3.0.) + sql = ( + "CREATE %(unique)s INDEX IF NOT EXISTS %(name)s ON %(table)s" + " (%(columns)s)" + ) % { + "unique": "UNIQUE" if unique else "", + "name": index_name, + "table": table, + "columns": ", ".join(columns), + } + c = conn.cursor() + logger.debug("[SQL] %s", sql) c.execute(sql) + if isinstance(self.database_engine, engines.PostgresEngine): + runner = create_index_psql + elif psql_only: + runner = None + else: + runner = create_index_sqlite + @defer.inlineCallbacks def updater(progress, batch_size): - logger.info("Adding index %s to %s", index_name, table) - if conc: - yield self.runWithConnection(create_index_concurrently) - else: - yield self.runWithConnection(create_index) + if runner is not None: + logger.info("Adding index %s to %s", index_name, table) + yield self.runWithConnection(runner) yield self._end_background_update(update_name) defer.returnValue(1) diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py index 71e5ea112f..7b44dae0fc 100644 --- a/synapse/storage/client_ips.py +++ b/synapse/storage/client_ips.py @@ -15,11 +15,14 @@ import logging -from twisted.internet import defer +from twisted.internet import defer, reactor from ._base import Cache from . import background_updates +from synapse.util.caches import CACHE_SIZE_FACTOR + + logger = logging.getLogger(__name__) # Number of msec of granularity to store the user IP 'last seen' time. Smaller @@ -29,13 +32,14 @@ LAST_SEEN_GRANULARITY = 120 * 1000 class ClientIpStore(background_updates.BackgroundUpdateStore): - def __init__(self, hs): + def __init__(self, db_conn, hs): self.client_ip_last_seen = Cache( name="client_ip_last_seen", keylen=4, + max_entries=50000 * CACHE_SIZE_FACTOR, ) - super(ClientIpStore, self).__init__(hs) + super(ClientIpStore, self).__init__(db_conn, hs) self.register_background_index_update( "user_ips_device_index", @@ -44,10 +48,26 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): columns=["user_id", "device_id", "last_seen"], ) - @defer.inlineCallbacks - def insert_client_ip(self, user, access_token, ip, user_agent, device_id): - now = int(self._clock.time_msec()) - key = (user.to_string(), access_token, ip) + self.register_background_index_update( + "user_ips_last_seen_index", + index_name="user_ips_last_seen", + table="user_ips", + columns=["user_id", "last_seen"], + ) + + # (user_id, access_token, ip) -> (user_agent, device_id, last_seen) + self._batch_row_update = {} + + self._client_ip_looper = self._clock.looping_call( + self._update_client_ips_batch, 5 * 1000 + ) + reactor.addSystemEventTrigger("before", "shutdown", self._update_client_ips_batch) + + def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id, + now=None): + if not now: + now = int(self._clock.time_msec()) + key = (user_id, access_token, ip) try: last_seen = self.client_ip_last_seen.get(key) @@ -56,34 +76,48 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): # Rate-limited inserts if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: - defer.returnValue(None) + return self.client_ip_last_seen.prefill(key, now) - # It's safe not to lock here: a) no unique constraint, - # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely - yield self._simple_upsert( - "user_ips", - keyvalues={ - "user_id": user.to_string(), - "access_token": access_token, - "ip": ip, - "user_agent": user_agent, - "device_id": device_id, - }, - values={ - "last_seen": now, - }, - desc="insert_client_ip", - lock=False, + self._batch_row_update[key] = (user_agent, device_id, now) + + def _update_client_ips_batch(self): + to_update = self._batch_row_update + self._batch_row_update = {} + return self.runInteraction( + "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) + def _update_client_ips_batch_txn(self, txn, to_update): + self.database_engine.lock_table(txn, "user_ips") + + for entry in to_update.iteritems(): + (user_id, access_token, ip), (user_agent, device_id, last_seen) = entry + + self._simple_upsert_txn( + txn, + table="user_ips", + keyvalues={ + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": device_id, + }, + values={ + "last_seen": last_seen, + }, + lock=False, + ) + @defer.inlineCallbacks - def get_last_client_ip_by_device(self, devices): + def get_last_client_ip_by_device(self, user_id, device_id): """For each device_id listed, give the user_ip it was last seen on Args: - devices (iterable[(str, str)]): list of (user_id, device_id) pairs + user_id (str) + device_id (str): If None fetches all devices for the user Returns: defer.Deferred: resolves to a dict, where the keys @@ -94,6 +128,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): res = yield self.runInteraction( "get_last_client_ip_by_device", self._get_last_client_ip_by_device_txn, + user_id, device_id, retcols=( "user_id", "access_token", @@ -102,23 +137,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): "device_id", "last_seen", ), - devices=devices ) ret = {(d["user_id"], d["device_id"]): d for d in res} + for key in self._batch_row_update: + uid, access_token, ip = key + if uid == user_id: + user_agent, did, last_seen = self._batch_row_update[key] + if not device_id or did == device_id: + ret[(user_id, device_id)] = { + "user_id": user_id, + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "device_id": did, + "last_seen": last_seen, + } defer.returnValue(ret) @classmethod - def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols): + def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols): where_clauses = [] bindings = [] - for (user_id, device_id) in devices: - if device_id is None: - where_clauses.append("(user_id = ? AND device_id IS NULL)") - bindings.extend((user_id, )) - else: - where_clauses.append("(user_id = ? AND device_id = ?)") - bindings.extend((user_id, device_id)) + if device_id is None: + where_clauses.append("user_id = ?") + bindings.extend((user_id, )) + else: + where_clauses.append("(user_id = ? AND device_id = ?)") + bindings.extend((user_id, device_id)) + + if not where_clauses: + return [] inner_select = ( "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " @@ -143,3 +192,37 @@ class ClientIpStore(background_updates.BackgroundUpdateStore): txn.execute(sql, bindings) return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def get_user_ip_and_agents(self, user): + user_id = user.to_string() + results = {} + + for key in self._batch_row_update: + uid, access_token, ip = key + if uid == user_id: + user_agent, _, last_seen = self._batch_row_update[key] + results[(access_token, ip)] = (user_agent, last_seen) + + rows = yield self._simple_select_list( + table="user_ips", + keyvalues={"user_id": user_id}, + retcols=[ + "access_token", "ip", "user_agent", "last_seen" + ], + desc="get_user_ip_and_agents", + ) + + results.update( + ((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"])) + for row in rows + ) + defer.returnValue(list( + { + "access_token": access_token, + "ip": ip, + "user_agent": user_agent, + "last_seen": last_seen, + } + for (access_token, ip), (user_agent, last_seen) in results.iteritems() + )) diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index bde3b5cbbc..a879e5bfc1 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -14,12 +14,14 @@ # limitations under the License. import logging -import ujson +import simplejson from twisted.internet import defer from .background_updates import BackgroundUpdateStore +from synapse.util.caches.expiringcache import ExpiringCache + logger = logging.getLogger(__name__) @@ -27,8 +29,8 @@ logger = logging.getLogger(__name__) class DeviceInboxStore(BackgroundUpdateStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - def __init__(self, hs): - super(DeviceInboxStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(DeviceInboxStore, self).__init__(db_conn, hs) self.register_background_index_update( "device_inbox_stream_index", @@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore): self._background_drop_index_device_inbox, ) + # Map of (user_id, device_id) to the last stream_id that has been + # deleted up to. This is so that we can no op deletions. + self._last_device_delete_cache = ExpiringCache( + cache_name="last_device_delete_cache", + clock=self._clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + ) + @defer.inlineCallbacks def add_messages_to_device_inbox(self, local_messages_by_user_then_device, remote_messages_by_destination): @@ -74,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore): ) rows = [] for destination, edu in remote_messages_by_destination.items(): - edu_json = ujson.dumps(edu) + edu_json = simplejson.dumps(edu) rows.append((destination, stream_id, now_ms, edu_json)) txn.executemany(sql, rows) @@ -166,8 +177,8 @@ class DeviceInboxStore(BackgroundUpdateStore): " WHERE user_id = ?" ) txn.execute(sql, (user_id,)) - message_json = ujson.dumps(messages_by_device["*"]) - for row in txn.fetchall(): + message_json = simplejson.dumps(messages_by_device["*"]) + for row in txn: # Add the message for all devices for this user on this # server. device = row[0] @@ -184,11 +195,11 @@ class DeviceInboxStore(BackgroundUpdateStore): # TODO: Maybe this needs to be done in batches if there are # too many local devices for a given user. txn.execute(sql, [user_id] + devices) - for row in txn.fetchall(): + for row in txn: # Only insert into the local inbox if the device exists on # this server device = row[0] - message_json = ujson.dumps(messages_by_device[device]) + message_json = simplejson.dumps(messages_by_device[device]) messages_json_for_user[device] = message_json if messages_json_for_user: @@ -240,9 +251,9 @@ class DeviceInboxStore(BackgroundUpdateStore): user_id, device_id, last_stream_id, current_stream_id, limit )) messages = [] - for row in txn.fetchall(): + for row in txn: stream_pos = row[0] - messages.append(ujson.loads(row[1])) + messages.append(simplejson.loads(row[1])) if len(messages) < limit: stream_pos = current_stream_id return (messages, stream_pos) @@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore): "get_new_messages_for_device", get_new_messages_for_device_txn, ) + @defer.inlineCallbacks def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): """ Args: @@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore): Returns: A deferred that resolves to the number of messages deleted. """ + # If we have cached the last stream id we've deleted up to, we can + # check if there is likely to be anything that needs deleting + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), None + ) + if last_deleted_stream_id: + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_deleted_stream_id + ) + if not has_changed: + defer.returnValue(0) + def delete_messages_for_device_txn(txn): sql = ( "DELETE FROM device_inbox" @@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - return self.runInteraction( + count = yield self.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) + # Update the cache, ensuring that we only ever increase the value + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), 0 + ) + self._last_device_delete_cache[(user_id, device_id)] = max( + last_deleted_stream_id, up_to_stream_id + ) + + defer.returnValue(count) + def get_all_new_device_messages(self, last_pos, current_pos, limit): """ Args: @@ -291,22 +325,25 @@ class DeviceInboxStore(BackgroundUpdateStore): # we return. upper_pos = min(current_pos, last_pos + limit) sql = ( - "SELECT stream_id, user_id" + "SELECT max(stream_id), user_id" " FROM device_inbox" " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" + " GROUP BY user_id" ) txn.execute(sql, (last_pos, upper_pos)) rows = txn.fetchall() sql = ( - "SELECT stream_id, destination" + "SELECT max(stream_id), destination" " FROM device_federation_outbox" " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" + " GROUP BY destination" ) txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn.fetchall()) + rows.extend(txn) + + # Order by ascending stream ordering + rows.sort() return rows @@ -323,12 +360,12 @@ class DeviceInboxStore(BackgroundUpdateStore): """ Args: destination(str): The name of the remote server. - last_stream_id(int): The last position of the device message stream + last_stream_id(int|long): The last position of the device message stream that the server sent up to. - current_stream_id(int): The current position of the device + current_stream_id(int|long): The current position of the device message stream. Returns: - Deferred ([dict], int): List of messages for the device and where + Deferred ([dict], int|long): List of messages for the device and where in the stream the messages got to. """ @@ -350,9 +387,9 @@ class DeviceInboxStore(BackgroundUpdateStore): destination, last_stream_id, current_stream_id, limit )) messages = [] - for row in txn.fetchall(): + for row in txn: stream_pos = row[0] - messages.append(ujson.loads(row[1])) + messages.append(simplejson.loads(row[1])) if len(messages) < limit: stream_pos = current_stream_id return (messages, stream_pos) diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 8e17800364..712106b83a 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -13,24 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import ujson as json +import simplejson as json from twisted.internet import defer from synapse.api.errors import StoreError -from ._base import SQLBaseStore +from ._base import SQLBaseStore, Cache +from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks + logger = logging.getLogger(__name__) class DeviceStore(SQLBaseStore): - def __init__(self, hs): - super(DeviceStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(DeviceStore, self).__init__(db_conn, hs) + + # Map of (user_id, device_id) -> bool. If there is an entry that implies + # the device exists. + self.device_id_exists_cache = Cache( + name="device_id_exists", + keylen=2, + max_entries=10000, + ) self._clock.looping_call( self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) + self.register_background_index_update( + "device_lists_stream_idx", + index_name="device_lists_stream_user_id", + table="device_lists_stream", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks def store_device(self, user_id, device_id, initial_device_display_name): @@ -45,6 +62,10 @@ class DeviceStore(SQLBaseStore): defer.Deferred: boolean whether the device was inserted or an existing device existed with that ID. """ + key = (user_id, device_id) + if self.device_id_exists_cache.get(key, None): + defer.returnValue(False) + try: inserted = yield self._simple_insert( "devices", @@ -56,6 +77,7 @@ class DeviceStore(SQLBaseStore): desc="store_device", or_ignore=True, ) + self.device_id_exists_cache.prefill(key, True) defer.returnValue(inserted) except Exception as e: logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" @@ -84,6 +106,7 @@ class DeviceStore(SQLBaseStore): desc="get_device", ) + @defer.inlineCallbacks def delete_device(self, user_id, device_id): """Delete a device. @@ -93,12 +116,34 @@ class DeviceStore(SQLBaseStore): Returns: defer.Deferred """ - return self._simple_delete_one( + yield self._simple_delete_one( table="devices", keyvalues={"user_id": user_id, "device_id": device_id}, desc="delete_device", ) + self.device_id_exists_cache.invalidate((user_id, device_id)) + + @defer.inlineCallbacks + def delete_devices(self, user_id, device_ids): + """Deletes several devices. + + Args: + user_id (str): The ID of the user which owns the devices + device_ids (list): The IDs of the devices to delete + Returns: + defer.Deferred + """ + yield self._simple_delete_many( + table="devices", + column="device_id", + iterable=device_ids, + keyvalues={"user_id": user_id}, + desc="delete_devices", + ) + for device_id in device_ids: + self.device_id_exists_cache.invalidate((user_id, device_id)) + def update_device(self, user_id, device_id, new_display_name=None): """Update a device. @@ -144,6 +189,7 @@ class DeviceStore(SQLBaseStore): defer.returnValue({d["device_id"]: d for d in devices}) + @cached(max_entries=10000) def get_device_list_last_stream_id_for_remote(self, user_id): """Get the last stream_id we got for a user. May be None if we haven't got any information for them. @@ -156,16 +202,36 @@ class DeviceStore(SQLBaseStore): allow_none=True, ) + @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote", + list_name="user_ids", inlineCallbacks=True) + def get_device_list_last_stream_id_for_remotes(self, user_ids): + rows = yield self._simple_select_many_batch( + table="device_lists_remote_extremeties", + column="user_id", + iterable=user_ids, + retcols=("user_id", "stream_id",), + desc="get_user_devices_from_cache", + ) + + results = {user_id: None for user_id in user_ids} + results.update({ + row["user_id"]: row["stream_id"] for row in rows + }) + + defer.returnValue(results) + + @defer.inlineCallbacks def mark_remote_user_device_list_as_unsubscribed(self, user_id): """Mark that we no longer track device lists for remote user. """ - return self._simple_delete( + yield self._simple_delete( table="device_lists_remote_extremeties", keyvalues={ "user_id": user_id, }, desc="mark_remote_user_device_list_as_unsubscribed", ) + self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) def update_remote_device_list_cache_entry(self, user_id, device_id, content, stream_id): @@ -191,6 +257,12 @@ class DeviceStore(SQLBaseStore): } ) + txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + self._simple_upsert_txn( txn, table="device_lists_remote_extremeties", @@ -234,6 +306,12 @@ class DeviceStore(SQLBaseStore): ] ) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + self._simple_upsert_txn( txn, table="device_lists_remote_extremeties", @@ -249,7 +327,7 @@ class DeviceStore(SQLBaseStore): """Get stream of updates to send to remote servers Returns: - (now_stream_id, [ { updates }, .. ]) + (int, list[dict]): current stream id and list of updates """ now_stream_id = self._device_list_id_gen.get_current_token() @@ -270,24 +348,27 @@ class DeviceStore(SQLBaseStore): SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? GROUP BY user_id, device_id + LIMIT 20 """ txn.execute( sql, (destination, from_stream_id, now_stream_id, False) ) - rows = txn.fetchall() - if not rows: + # maps (user_id, device_id) -> stream_id + query_map = {(r[0], r[1]): r[2] for r in txn} + if not query_map: return (now_stream_id, []) - # maps (user_id, device_id) -> stream_id - query_map = {(r[0], r[1]): r[2] for r in rows} + if len(query_map) >= 20: + now_stream_id = max(stream_id for stream_id in query_map.itervalues()) + devices = self._get_e2e_device_keys_txn( txn, query_map.keys(), include_all_devices=True ) prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id - FROM device_lists_outbound_pokes + FROM device_lists_outbound_last_success WHERE destination = ? AND user_id = ? AND stream_id <= ? """ @@ -320,6 +401,7 @@ class DeviceStore(SQLBaseStore): return (now_stream_id, results) + @defer.inlineCallbacks def get_user_devices_from_cache(self, query_list): """Get the devices (and keys if any) for remote users from the cache. @@ -332,27 +414,11 @@ class DeviceStore(SQLBaseStore): a set of user_ids and results_map is a mapping of user_id -> device_id -> device_info """ - return self.runInteraction( - "get_user_devices_from_cache", self._get_user_devices_from_cache_txn, - query_list, + user_ids = set(user_id for user_id, _ in query_list) + user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_ids_in_cache = set( + user_id for user_id, stream_id in user_map.items() if stream_id ) - - def _get_user_devices_from_cache_txn(self, txn, query_list): - user_ids = {user_id for user_id, _ in query_list} - - user_ids_in_cache = set() - for user_id in user_ids: - stream_ids = self._simple_select_onecol_txn( - txn, - table="device_lists_remote_extremeties", - keyvalues={ - "user_id": user_id, - }, - retcol="stream_id", - ) - if stream_ids: - user_ids_in_cache.add(user_id) - user_ids_not_in_cache = user_ids - user_ids_in_cache results = {} @@ -361,32 +427,40 @@ class DeviceStore(SQLBaseStore): continue if device_id: - content = self._simple_select_one_onecol_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - retcol="content", - ) - results.setdefault(user_id, {})[device_id] = json.loads(content) + device = yield self._get_cached_user_device(user_id, device_id) + results.setdefault(user_id, {})[device_id] = device else: - devices = self._simple_select_list_txn( - txn, - table="device_lists_remote_cache", - keyvalues={ - "user_id": user_id, - }, - retcols=("device_id", "content"), - ) - results[user_id] = { - device["device_id"]: json.loads(device["content"]) - for device in devices - } - user_ids_in_cache.discard(user_id) + results[user_id] = yield self._get_cached_devices_for_user(user_id) - return user_ids_not_in_cache, results + defer.returnValue((user_ids_not_in_cache, results)) + + @cachedInlineCallbacks(num_args=2, tree=True) + def _get_cached_user_device(self, user_id, device_id): + content = yield self._simple_select_one_onecol( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + desc="_get_cached_user_device", + ) + defer.returnValue(json.loads(content)) + + @cachedInlineCallbacks() + def _get_cached_devices_for_user(self, user_id): + devices = yield self._simple_select_list( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + desc="_get_cached_devices_for_user", + ) + defer.returnValue({ + device["device_id"]: json.loads(device["content"]) + for device in devices + }) def get_devices_with_keys_by_user(self, user_id): """Get all devices (with any device keys) for a user @@ -436,32 +510,43 @@ class DeviceStore(SQLBaseStore): ) def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): - # First we DELETE all rows such that only the latest row for each - # (destination, user_id is left. We do this by selecting first and - # deleting. + # We update the device_lists_outbound_last_success with the successfully + # poked users. We do the join to see which users need to be inserted and + # which updated. sql = """ - SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes - WHERE destination = ? AND stream_id <= ? + SELECT user_id, coalesce(max(o.stream_id), 0), (max(s.stream_id) IS NOT NULL) + FROM device_lists_outbound_pokes as o + LEFT JOIN device_lists_outbound_last_success as s + USING (destination, user_id) + WHERE destination = ? AND o.stream_id <= ? GROUP BY user_id - HAVING count(*) > 1 """ txn.execute(sql, (destination, stream_id,)) rows = txn.fetchall() sql = """ - DELETE FROM device_lists_outbound_pokes - WHERE destination = ? AND user_id = ? AND stream_id < ? + UPDATE device_lists_outbound_last_success + SET stream_id = ? + WHERE destination = ? AND user_id = ? + """ + txn.executemany( + sql, ((row[1], destination, row[0],) for row in rows if row[2]) + ) + + sql = """ + INSERT INTO device_lists_outbound_last_success + (destination, user_id, stream_id) VALUES (?, ?, ?) """ txn.executemany( - sql, ((destination, row[0], row[1],) for row in rows) + sql, ((destination, row[0], row[1],) for row in rows if not row[2]) ) - # Mark everything that is left as sent + # Delete all sent outbound pokes sql = """ - UPDATE device_lists_outbound_pokes SET sent = ? + DELETE FROM device_lists_outbound_pokes WHERE destination = ? AND stream_id <= ? """ - txn.execute(sql, (True, destination, stream_id,)) + txn.execute(sql, (destination, stream_id,)) @defer.inlineCallbacks def get_user_whose_devices_changed(self, from_key): @@ -473,12 +558,12 @@ class DeviceStore(SQLBaseStore): defer.returnValue(set(changed)) sql = """ - SELECT user_id FROM device_lists_stream WHERE stream_id > ? + SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? """ rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) defer.returnValue(set(row[0] for row in rows)) - def get_all_device_list_changes_for_remotes(self, from_key): + def get_all_device_list_changes_for_remotes(self, from_key, to_key): """Return a list of `(stream_id, user_id, destination)` which is the combined list of changes to devices, and which destinations need to be poked. `destination` may be None if no destinations need to be poked. @@ -486,11 +571,11 @@ class DeviceStore(SQLBaseStore): sql = """ SELECT stream_id, user_id, destination FROM device_lists_stream LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) - WHERE stream_id > ? + WHERE ? < stream_id AND stream_id <= ? """ return self._execute( - "get_users_and_hosts_device_list", None, - sql, from_key, + "get_all_device_list_changes_for_remotes", None, + sql, from_key, to_key ) @defer.inlineCallbacks @@ -518,6 +603,16 @@ class DeviceStore(SQLBaseStore): host, stream_id, ) + # Delete older entries in the table, as we really only care about + # when the latest change happened. + txn.executemany( + """ + DELETE FROM device_lists_stream + WHERE user_id = ? AND device_id = ? AND stream_id < ? + """, + [(user_id, device_id, stream_id) for device_id in device_ids] + ) + self._simple_insert_many_txn( txn, table="device_lists_stream", @@ -586,6 +681,14 @@ class DeviceStore(SQLBaseStore): ) ) + # Since we've deleted unsent deltas, we need to remove the entry + # of last successful sent so that the prev_ids are correctly set. + sql = """ + DELETE FROM device_lists_outbound_last_success + WHERE destination = ? AND user_id = ? + """ + txn.executemany(sql, ((row[0], row[1]) for row in rows)) + logger.info("Pruned %d device list outbound pokes", txn.rowcount) return self.runInteraction( diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py index 9caaf81f2c..d0c0059757 100644 --- a/synapse/storage/directory.py +++ b/synapse/storage/directory.py @@ -29,8 +29,7 @@ RoomAliasMapping = namedtuple( ) -class DirectoryStore(SQLBaseStore): - +class DirectoryWorkerStore(SQLBaseStore): @defer.inlineCallbacks def get_association_from_room_alias(self, room_alias): """ Get's the room_id and server list for a given room_alias @@ -69,6 +68,28 @@ class DirectoryStore(SQLBaseStore): RoomAliasMapping(room_id, room_alias.to_string(), servers) ) + def get_room_alias_creator(self, room_alias): + return self._simple_select_one_onecol( + table="room_aliases", + keyvalues={ + "room_alias": room_alias, + }, + retcol="creator", + desc="get_room_alias_creator", + allow_none=True + ) + + @cached(max_entries=5000) + def get_aliases_for_room(self, room_id): + return self._simple_select_onecol( + "room_aliases", + {"room_id": room_id}, + "room_alias", + desc="get_aliases_for_room", + ) + + +class DirectoryStore(DirectoryWorkerStore): @defer.inlineCallbacks def create_room_alias_association(self, room_alias, room_id, servers, creator=None): """ Creates an associatin between a room alias and room_id/servers @@ -116,17 +137,6 @@ class DirectoryStore(SQLBaseStore): ) defer.returnValue(ret) - def get_room_alias_creator(self, room_alias): - return self._simple_select_one_onecol( - table="room_aliases", - keyvalues={ - "room_alias": room_alias, - }, - retcol="creator", - desc="get_room_alias_creator", - allow_none=True - ) - @defer.inlineCallbacks def delete_room_alias(self, room_alias): room_id = yield self.runInteraction( @@ -135,7 +145,6 @@ class DirectoryStore(SQLBaseStore): room_alias, ) - self.get_aliases_for_room.invalidate((room_id,)) defer.returnValue(room_id) def _delete_room_alias_txn(self, txn, room_alias): @@ -160,13 +169,22 @@ class DirectoryStore(SQLBaseStore): (room_alias.to_string(),) ) + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (room_id,) + ) + return room_id - @cached(max_entries=5000) - def get_aliases_for_room(self, room_id): - return self._simple_select_onecol( - "room_aliases", - {"room_id": room_id}, - "room_alias", - desc="get_aliases_for_room", + def update_aliases_for_room(self, old_room_id, new_room_id, creator): + def _update_aliases_for_room_txn(txn): + sql = "UPDATE room_aliases SET room_id = ?, creator = ? WHERE room_id = ?" + txn.execute(sql, (new_room_id, creator, old_room_id,)) + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (old_room_id,) + ) + self._invalidate_cache_and_stream( + txn, self.get_aliases_for_room, (new_room_id,) + ) + return self.runInteraction( + "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index b9f1365f92..ff8538ddf8 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -14,8 +14,10 @@ # limitations under the License. from twisted.internet import defer +from synapse.util.caches.descriptors import cached + from canonicaljson import encode_canonical_json -import ujson as json +import simplejson as json from ._base import SQLBaseStore @@ -120,26 +122,77 @@ class EndToEndKeyStore(SQLBaseStore): return result - def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): + @defer.inlineCallbacks + def get_e2e_one_time_keys(self, user_id, device_id, key_ids): + """Retrieve a number of one-time keys for a user + + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + key_ids(list[str]): list of key ids (excluding algorithm) to + retrieve + + Returns: + deferred resolving to Dict[(str, str), str]: map from (algorithm, + key_id) to json string for key + """ + + rows = yield self._simple_select_many_batch( + table="e2e_one_time_keys_json", + column="key_id", + iterable=key_ids, + retcols=("algorithm", "key_id", "key_json",), + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + desc="add_e2e_one_time_keys_check", + ) + + defer.returnValue({ + (row["algorithm"], row["key_id"]): row["key_json"] for row in rows + }) + + @defer.inlineCallbacks + def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys): + """Insert some new one time keys for a device. Errors if any of the + keys already exist. + + Args: + user_id(str): id of user to get keys for + device_id(str): id of device to get keys for + time_now(long): insertion time to record (ms since epoch) + new_keys(iterable[(str, str, str)]: keys to add - each a tuple of + (algorithm, key_id, key json) + """ + def _add_e2e_one_time_keys(txn): - for (algorithm, key_id, json_bytes) in key_list: - self._simple_upsert_txn( - txn, table="e2e_one_time_keys_json", - keyvalues={ + # We are protected from race between lookup and insertion due to + # a unique constraint. If there is a race of two calls to + # `add_e2e_one_time_keys` then they'll conflict and we will only + # insert one set. + self._simple_insert_many_txn( + txn, table="e2e_one_time_keys_json", + values=[ + { "user_id": user_id, "device_id": device_id, "algorithm": algorithm, "key_id": key_id, - }, - values={ "ts_added_ms": time_now, "key_json": json_bytes, } - ) - return self.runInteraction( - "add_e2e_one_time_keys", _add_e2e_one_time_keys + for algorithm, key_id, json_bytes in new_keys + ], + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) + ) + yield self.runInteraction( + "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) + @cached(max_entries=10000) def count_e2e_one_time_keys(self, user_id, device_id): """ Count the number of one time keys the server has for a device Returns: @@ -153,7 +206,7 @@ class EndToEndKeyStore(SQLBaseStore): ) txn.execute(sql, (user_id, device_id)) result = {} - for algorithm, key_count in txn.fetchall(): + for algorithm, key_count in txn: result[algorithm] = key_count return result return self.runInteraction( @@ -174,7 +227,7 @@ class EndToEndKeyStore(SQLBaseStore): user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) - for key_id, key_json in txn.fetchall(): + for key_id, key_json in txn: device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) sql = ( @@ -184,20 +237,29 @@ class EndToEndKeyStore(SQLBaseStore): ) for user_id, device_id, algorithm, key_id in delete: txn.execute(sql, (user_id, device_id, algorithm, key_id)) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) + ) return result return self.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - @defer.inlineCallbacks def delete_e2e_keys_by_device(self, user_id, device_id): - yield self._simple_delete( - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="delete_e2e_device_keys_by_device" - ) - yield self._simple_delete( - table="e2e_one_time_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - desc="delete_e2e_one_time_keys_by_device" + def delete_e2e_keys_by_device_txn(txn): + self._simple_delete_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._simple_delete_txn( + txn, + table="e2e_one_time_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id,) + ) + return self.runInteraction( + "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn ) diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 338b495611..8c868ece75 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -18,6 +18,7 @@ from .postgres import PostgresEngine from .sqlite3 import Sqlite3Engine import importlib +import platform SUPPORTED_MODULE = { @@ -31,6 +32,10 @@ def create_engine(database_config): engine_class = SUPPORTED_MODULE.get(name, None) if engine_class: + # pypy requires psycopg2cffi rather than psycopg2 + if (name == "psycopg2" and + platform.python_implementation() == "PyPy"): + name = "psycopg2cffi" module = importlib.import_module(name) return engine_class(module, database_config) diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index a6ae79dfad..8a0386c1a4 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -62,3 +62,9 @@ class PostgresEngine(object): def lock_table(self, txn, table): txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,)) + + def get_next_state_group_id(self, txn): + """Returns an int that can be used as a new state_group ID + """ + txn.execute("SELECT nextval('state_group_id_seq')") + return txn.fetchone()[0] diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py index 755c9a1f07..60f0fa7fb3 100644 --- a/synapse/storage/engines/sqlite3.py +++ b/synapse/storage/engines/sqlite3.py @@ -16,6 +16,7 @@ from synapse.storage.prepare_database import prepare_database import struct +import threading class Sqlite3Engine(object): @@ -24,6 +25,11 @@ class Sqlite3Engine(object): def __init__(self, database_module, database_config): self.module = database_module + # The current max state_group, or None if we haven't looked + # in the DB yet. + self._current_state_group_id = None + self._current_state_group_id_lock = threading.Lock() + def check_database(self, txn): pass @@ -43,6 +49,19 @@ class Sqlite3Engine(object): def lock_table(self, txn, table): return + def get_next_state_group_id(self, txn): + """Returns an int that can be used as a new state_group ID + """ + # We do application locking here since if we're using sqlite then + # we are a single process synapse. + with self._current_state_group_id_lock: + if self._current_state_group_id is None: + txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups") + self._current_state_group_id = txn.fetchone()[0] + + self._current_state_group_id += 1 + return self._current_state_group_id + # Following functions taken from: https://github.com/coleifer/peewee diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 256e50dc20..8fbf7ffba7 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -12,50 +12,64 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random from twisted.internet import defer -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.events import EventsWorkerStore +from synapse.storage.signatures import SignatureWorkerStore + from synapse.api.errors import StoreError from synapse.util.caches.descriptors import cached from unpaddedbase64 import encode_base64 import logging -from Queue import PriorityQueue, Empty +from six.moves.queue import PriorityQueue, Empty + +from six.moves import range logger = logging.getLogger(__name__) -class EventFederationStore(SQLBaseStore): - """ Responsible for storing and serving up the various graphs associated - with an event. Including the main event graph and the auth chains for an - event. +class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, + SQLBaseStore): + def get_auth_chain(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. - Also has methods for getting the front (latest) and back (oldest) edges - of the event graphs. These are used to generate the parents for new events - and backfilling from another server respectively. - """ + Args: + event_ids (list): state events + include_given (bool): include the given events in result - def __init__(self, hs): - super(EventFederationStore, self).__init__(hs) + Returns: + list of events + """ + return self.get_auth_chain_ids( + event_ids, include_given=include_given, + ).addCallback(self._get_events) - hs.get_clock().looping_call( - self._delete_old_forward_extrem_cache, 60 * 60 * 1000 - ) + def get_auth_chain_ids(self, event_ids, include_given=False): + """Get auth events for given event_ids. The events *must* be state events. - def get_auth_chain(self, event_ids): - return self.get_auth_chain_ids(event_ids).addCallback(self._get_events) + Args: + event_ids (list): state events + include_given (bool): include the given events in result - def get_auth_chain_ids(self, event_ids): + Returns: + list of event_ids + """ return self.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, - event_ids + event_ids, include_given ) - def _get_auth_chain_ids_txn(self, txn, event_ids): - results = set() + def _get_auth_chain_ids_txn(self, txn, event_ids, include_given): + if include_given: + results = set(event_ids) + else: + results = set() base_sql = ( "SELECT auth_id FROM event_auth WHERE event_id IN (%s)" @@ -67,14 +81,14 @@ class EventFederationStore(SQLBaseStore): front_list = list(front) chunks = [ front_list[x:x + 100] - for x in xrange(0, len(front), 100) + for x in range(0, len(front), 100) ] for chunk in chunks: txn.execute( base_sql % (",".join(["?"] * len(chunk)),), chunk ) - new_front.update([r[0] for r in txn.fetchall()]) + new_front.update([r[0] for r in txn]) new_front -= results @@ -110,7 +124,7 @@ class EventFederationStore(SQLBaseStore): txn.execute(sql, (room_id, False,)) - return dict(txn.fetchall()) + return dict(txn) def _get_oldest_events_in_room_txn(self, txn, room_id): return self._simple_select_onecol_txn( @@ -122,7 +136,47 @@ class EventFederationStore(SQLBaseStore): retcol="event_id", ) + @defer.inlineCallbacks + def get_prev_events_for_room(self, room_id): + """ + Gets a subset of the current forward extremities in the given room. + + Limits the result to 10 extremities, so that we can avoid creating + events which refer to hundreds of prev_events. + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + res = yield self.get_latest_event_ids_and_hashes_in_room(room_id) + if len(res) > 10: + # Sort by reverse depth, so we point to the most recent. + res.sort(key=lambda a: -a[2]) + + # we use half of the limit for the actual most recent events, and + # the other half to randomly point to some of the older events, to + # make sure that we don't completely ignore the older events. + res = res[0:5] + random.sample(res[5:], 5) + + defer.returnValue(res) + def get_latest_event_ids_and_hashes_in_room(self, room_id): + """ + Gets the current forward extremities in the given room + + Args: + room_id (str): room_id + + Returns: + Deferred[list[(str, dict[str, str], int)]] + for each event, a tuple of (event_id, hashes, depth) + where *hashes* is a map from algorithm to hash. + """ + return self.runInteraction( "get_latest_event_ids_and_hashes_in_room", self._get_latest_event_ids_and_hashes_in_room, @@ -171,22 +225,6 @@ class EventFederationStore(SQLBaseStore): room_id, ) - @defer.inlineCallbacks - def get_max_depth_of_events(self, event_ids): - sql = ( - "SELECT MAX(depth) FROM events WHERE event_id IN (%s)" - ) % (",".join(["?"] * len(event_ids)),) - - rows = yield self._execute( - "get_max_depth_of_events", None, - sql, *event_ids - ) - - if rows: - defer.returnValue(rows[0][0]) - else: - defer.returnValue(1) - def _get_min_depth_interaction(self, txn, room_id): min_depth = self._simple_select_one_onecol_txn( txn, @@ -198,88 +236,6 @@ class EventFederationStore(SQLBaseStore): return int(min_depth) if min_depth is not None else None - def _update_min_depth_for_room_txn(self, txn, room_id, depth): - min_depth = self._get_min_depth_interaction(txn, room_id) - - do_insert = depth < min_depth if min_depth else True - - if do_insert: - self._simple_upsert_txn( - txn, - table="room_depth", - keyvalues={ - "room_id": room_id, - }, - values={ - "min_depth": depth, - }, - ) - - def _handle_mult_prev_events(self, txn, events): - """ - For the given event, update the event edges table and forward and - backward extremities tables. - """ - self._simple_insert_many_txn( - txn, - table="event_edges", - values=[ - { - "event_id": ev.event_id, - "prev_event_id": e_id, - "room_id": ev.room_id, - "is_state": False, - } - for ev in events - for e_id, _ in ev.prev_events - ], - ) - - self._update_backward_extremeties(txn, events) - - def _update_backward_extremeties(self, txn, events): - """Updates the event_backward_extremities tables based on the new/updated - events being persisted. - - This is called for new events *and* for events that were outliers, but - are now being persisted as non-outliers. - - Forward extremities are handled when we first start persisting the events. - """ - events_by_room = {} - for ev in events: - events_by_room.setdefault(ev.room_id, []).append(ev) - - query = ( - "INSERT INTO event_backward_extremities (event_id, room_id)" - " SELECT ?, ? WHERE NOT EXISTS (" - " SELECT 1 FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - " )" - " AND NOT EXISTS (" - " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " - " AND outlier = ?" - " )" - ) - - txn.executemany(query, [ - (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) - for ev in events for e_id, _ in ev.prev_events - if not ev.internal_metadata.is_outlier() - ]) - - query = ( - "DELETE FROM event_backward_extremities" - " WHERE event_id = ? AND room_id = ?" - ) - txn.executemany( - query, - [ - (ev.event_id, ev.room_id) for ev in events - if not ev.internal_metadata.is_outlier() - ] - ) - def get_forward_extremeties_for_room(self, room_id, stream_ordering): """For a given room_id and stream_ordering, return the forward extremeties of the room at that point in "time". @@ -334,36 +290,13 @@ class EventFederationStore(SQLBaseStore): def get_forward_extremeties_for_room_txn(txn): txn.execute(sql, (stream_ordering, room_id)) - rows = txn.fetchall() - return [event_id for event_id, in rows] + return [event_id for event_id, in txn] return self.runInteraction( "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn ) - def _delete_old_forward_extrem_cache(self): - def _delete_old_forward_extrem_cache_txn(txn): - # Delete entries older than a month, while making sure we don't delete - # the only entries for a room. - sql = (""" - DELETE FROM stream_ordering_to_exterm - WHERE - room_id IN ( - SELECT room_id - FROM stream_ordering_to_exterm - WHERE stream_ordering > ? - ) AND stream_ordering < ? - """) - txn.execute( - sql, - (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) - ) - return self.runInteraction( - "_delete_old_forward_extrem_cache", - _delete_old_forward_extrem_cache_txn - ) - def get_backfill_events(self, room_id, event_list, limit): """Get a list of Events for a given topic that occurred before (and including) the events in event_list. Return a list of max size `limit` @@ -436,7 +369,7 @@ class EventFederationStore(SQLBaseStore): (room_id, event_id, False, limit - len(event_results)) ) - for row in txn.fetchall(): + for row in txn: if row[1] not in event_results: queue.put((-row[0], row[1])) @@ -482,7 +415,7 @@ class EventFederationStore(SQLBaseStore): (room_id, event_id, False, limit - len(event_results)) ) - for e_id, in txn.fetchall(): + for e_id, in txn: new_front.add(e_id) new_front -= earliest_events @@ -493,6 +426,135 @@ class EventFederationStore(SQLBaseStore): return event_results + +class EventFederationStore(EventFederationWorkerStore): + """ Responsible for storing and serving up the various graphs associated + with an event. Including the main event graph and the auth chains for an + event. + + Also has methods for getting the front (latest) and back (oldest) edges + of the event graphs. These are used to generate the parents for new events + and backfilling from another server respectively. + """ + + EVENT_AUTH_STATE_ONLY = "event_auth_state_only" + + def __init__(self, db_conn, hs): + super(EventFederationStore, self).__init__(db_conn, hs) + + self.register_background_update_handler( + self.EVENT_AUTH_STATE_ONLY, + self._background_delete_non_state_event_auth, + ) + + hs.get_clock().looping_call( + self._delete_old_forward_extrem_cache, 60 * 60 * 1000 + ) + + def _update_min_depth_for_room_txn(self, txn, room_id, depth): + min_depth = self._get_min_depth_interaction(txn, room_id) + + if min_depth and depth >= min_depth: + return + + self._simple_upsert_txn( + txn, + table="room_depth", + keyvalues={ + "room_id": room_id, + }, + values={ + "min_depth": depth, + }, + ) + + def _handle_mult_prev_events(self, txn, events): + """ + For the given event, update the event edges table and forward and + backward extremities tables. + """ + self._simple_insert_many_txn( + txn, + table="event_edges", + values=[ + { + "event_id": ev.event_id, + "prev_event_id": e_id, + "room_id": ev.room_id, + "is_state": False, + } + for ev in events + for e_id, _ in ev.prev_events + ], + ) + + self._update_backward_extremeties(txn, events) + + def _update_backward_extremeties(self, txn, events): + """Updates the event_backward_extremities tables based on the new/updated + events being persisted. + + This is called for new events *and* for events that were outliers, but + are now being persisted as non-outliers. + + Forward extremities are handled when we first start persisting the events. + """ + events_by_room = {} + for ev in events: + events_by_room.setdefault(ev.room_id, []).append(ev) + + query = ( + "INSERT INTO event_backward_extremities (event_id, room_id)" + " SELECT ?, ? WHERE NOT EXISTS (" + " SELECT 1 FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + " )" + " AND NOT EXISTS (" + " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? " + " AND outlier = ?" + " )" + ) + + txn.executemany(query, [ + (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) + for ev in events for e_id, _ in ev.prev_events + if not ev.internal_metadata.is_outlier() + ]) + + query = ( + "DELETE FROM event_backward_extremities" + " WHERE event_id = ? AND room_id = ?" + ) + txn.executemany( + query, + [ + (ev.event_id, ev.room_id) for ev in events + if not ev.internal_metadata.is_outlier() + ] + ) + + def _delete_old_forward_extrem_cache(self): + def _delete_old_forward_extrem_cache_txn(txn): + # Delete entries older than a month, while making sure we don't delete + # the only entries for a room. + sql = (""" + DELETE FROM stream_ordering_to_exterm + WHERE + room_id IN ( + SELECT room_id + FROM stream_ordering_to_exterm + WHERE stream_ordering > ? + ) AND stream_ordering < ? + """) + txn.execute( + sql, + (self.stream_ordering_month_ago, self.stream_ordering_month_ago,) + ) + return self.runInteraction( + "_delete_old_forward_extrem_cache", + _delete_old_forward_extrem_cache_txn + ) + def clean_room_for_join(self, room_id): return self.runInteraction( "clean_room_for_join", @@ -505,3 +567,52 @@ class EventFederationStore(SQLBaseStore): txn.execute(query, (room_id,)) txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,)) + + @defer.inlineCallbacks + def _background_delete_non_state_event_auth(self, progress, batch_size): + def delete_event_auth(txn): + target_min_stream_id = progress.get("target_min_stream_id_inclusive") + max_stream_id = progress.get("max_stream_id_exclusive") + + if not target_min_stream_id or not max_stream_id: + txn.execute("SELECT COALESCE(MIN(stream_ordering), 0) FROM events") + rows = txn.fetchall() + target_min_stream_id = rows[0][0] + + txn.execute("SELECT COALESCE(MAX(stream_ordering), 0) FROM events") + rows = txn.fetchall() + max_stream_id = rows[0][0] + + min_stream_id = max_stream_id - batch_size + + sql = """ + DELETE FROM event_auth + WHERE event_id IN ( + SELECT event_id FROM events + LEFT JOIN state_events USING (room_id, event_id) + WHERE ? <= stream_ordering AND stream_ordering < ? + AND state_key IS null + ) + """ + + txn.execute(sql, (min_stream_id, max_stream_id,)) + + new_progress = { + "target_min_stream_id_inclusive": target_min_stream_id, + "max_stream_id_exclusive": min_stream_id, + } + + self._background_update_progress_txn( + txn, self.EVENT_AUTH_STATE_ONLY, new_progress + ) + + return min_stream_id >= target_min_stream_id + + result = yield self.runInteraction( + self.EVENT_AUTH_STATE_ONLY, delete_event_auth + ) + + if not result: + yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY) + + defer.returnValue(batch_size) diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 522d0114cb..c22762eb5c 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,131 +14,159 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore, LoggingTransaction from twisted.internet import defer +from synapse.util.async import sleep from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.types import RoomStreamToken from .stream import lower_bound import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) -class EventPushActionsStore(SQLBaseStore): - EPA_HIGHLIGHT_INDEX = "epa_highlight_index" +DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] +DEFAULT_HIGHLIGHT_ACTION = [ + "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"} +] - def __init__(self, hs): - self.stream_ordering_month_ago = None - super(EventPushActionsStore, self).__init__(hs) - self.register_background_index_update( - self.EPA_HIGHLIGHT_INDEX, - index_name="event_push_actions_u_highlight", - table="event_push_actions", - columns=["user_id", "stream_ordering"], - ) +def _serialize_action(actions, is_highlight): + """Custom serializer for actions. This allows us to "compress" common actions. - self.register_background_index_update( - "event_push_actions_highlights_index", - index_name="event_push_actions_highlights_index", - table="event_push_actions", - columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], - where_clause="highlight=1" - ) + We use the fact that most users have the same actions for notifs (and for + highlights). + We store these default actions as the empty string rather than the full JSON. + Since the empty string isn't valid JSON there is no risk of this clashing with + any real JSON actions + """ + if is_highlight: + if actions == DEFAULT_HIGHLIGHT_ACTION: + return "" # We use empty string as the column is non-NULL + else: + if actions == DEFAULT_NOTIF_ACTION: + return "" + return json.dumps(actions) - def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): - """ - Args: - event: the event set actions for - tuples: list of tuples of (user_id, actions) - """ - values = [] - for uid, actions in tuples: - values.append({ - 'room_id': event.room_id, - 'event_id': event.event_id, - 'user_id': uid, - 'actions': json.dumps(actions), - 'stream_ordering': event.internal_metadata.stream_ordering, - 'topological_ordering': event.depth, - 'notif': 1, - 'highlight': 1 if _action_has_highlight(actions) else 0, - }) - - for uid, __ in tuples: - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (event.room_id, uid) - ) - self._simple_insert_many_txn(txn, "event_push_actions", values) - @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) - def get_unread_event_push_actions_by_room_for_user( - self, room_id, user_id, last_read_event_id - ): - def _get_unread_event_push_actions_by_room(txn): - sql = ( - "SELECT stream_ordering, topological_ordering" - " FROM events" - " WHERE room_id = ? AND event_id = ?" - ) - txn.execute( - sql, (room_id, last_read_event_id) - ) - results = txn.fetchall() - if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0} - - stream_ordering = results[0][0] - topological_ordering = results[0][1] - token = RoomStreamToken( - topological_ordering, stream_ordering - ) +def _deserialize_action(actions, is_highlight): + """Custom deserializer for actions. This allows us to "compress" common actions + """ + if actions: + return json.loads(actions) - # First get number of notifications. - # We don't need to put a notif=1 clause as all rows always have - # notif=1 - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " user_id = ?" - " AND room_id = ?" - " AND %s" - ) % (lower_bound(token, self.database_engine, inclusive=False),) + if is_highlight: + return DEFAULT_HIGHLIGHT_ACTION + else: + return DEFAULT_NOTIF_ACTION - txn.execute(sql, (user_id, room_id)) - row = txn.fetchone() - notify_count = row[0] if row else 0 - # Now get the number of highlights - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " highlight = 1" - " AND user_id = ?" - " AND room_id = ?" - " AND %s" - ) % (lower_bound(token, self.database_engine, inclusive=False),) +class EventPushActionsWorkerStore(SQLBaseStore): + def __init__(self, db_conn, hs): + super(EventPushActionsWorkerStore, self).__init__(db_conn, hs) - txn.execute(sql, (user_id, room_id)) - row = txn.fetchone() - highlight_count = row[0] if row else 0 + # These get correctly set by _find_stream_orderings_for_times_txn + self.stream_ordering_month_ago = None + self.stream_ordering_day_ago = None + + cur = LoggingTransaction( + db_conn.cursor(), + name="_find_stream_orderings_for_times_txn", + database_engine=self.database_engine, + after_callbacks=[], + exception_callbacks=[], + ) + self._find_stream_orderings_for_times_txn(cur) + cur.close() - return { - "notify_count": notify_count, - "highlight_count": highlight_count, - } + self.find_stream_orderings_looping_call = self._clock.looping_call( + self._find_stream_orderings_for_times, 10 * 60 * 1000 + ) + @cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000) + def get_unread_event_push_actions_by_room_for_user( + self, room_id, user_id, last_read_event_id + ): ret = yield self.runInteraction( "get_unread_event_push_actions_by_room", - _get_unread_event_push_actions_by_room + self._get_unread_counts_by_receipt_txn, + room_id, user_id, last_read_event_id ) defer.returnValue(ret) + def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id, + last_read_event_id): + sql = ( + "SELECT stream_ordering, topological_ordering" + " FROM events" + " WHERE room_id = ? AND event_id = ?" + ) + txn.execute( + sql, (room_id, last_read_event_id) + ) + results = txn.fetchall() + if len(results) == 0: + return {"notify_count": 0, "highlight_count": 0} + + stream_ordering = results[0][0] + topological_ordering = results[0][1] + + return self._get_unread_counts_by_pos_txn( + txn, room_id, user_id, topological_ordering, stream_ordering + ) + + def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering, + stream_ordering): + token = RoomStreamToken( + topological_ordering, stream_ordering + ) + + # First get number of notifications. + # We don't need to put a notif=1 clause as all rows always have + # notif=1 + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " user_id = ?" + " AND room_id = ?" + " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + notify_count = row[0] if row else 0 + + txn.execute(""" + SELECT notif_count FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering > ? + """, (room_id, user_id, stream_ordering,)) + rows = txn.fetchall() + if rows: + notify_count += rows[0][0] + + # Now get the number of highlights + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " highlight = 1" + " AND user_id = ?" + " AND room_id = ?" + " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + highlight_count = row[0] if row else 0 + + return { + "notify_count": notify_count, + "highlight_count": highlight_count, + } + @defer.inlineCallbacks def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): def f(txn): @@ -146,7 +175,7 @@ class EventPushActionsStore(SQLBaseStore): " stream_ordering >= ? AND stream_ordering <= ?" ) txn.execute(sql, (min_stream_ordering, max_stream_ordering)) - return [r[0] for r in txn.fetchall()] + return [r[0] for r in txn] ret = yield self.runInteraction("get_push_action_users_in_range", f) defer.returnValue(ret) @@ -176,7 +205,8 @@ class EventPushActionsStore(SQLBaseStore): # find rooms that have a read receipt in them and return the next # push actions sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions" + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight " " FROM (" " SELECT room_id," " MAX(topological_ordering) as topological_ordering," @@ -217,7 +247,7 @@ class EventPushActionsStore(SQLBaseStore): def get_no_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight " " FROM event_push_actions AS ep" " INNER JOIN events AS e USING (room_id, event_id)" " WHERE" @@ -246,7 +276,7 @@ class EventPushActionsStore(SQLBaseStore): "event_id": row[0], "room_id": row[1], "stream_ordering": row[2], - "actions": json.loads(row[3]), + "actions": _deserialize_action(row[3], row[4]), } for row in after_read_receipt + no_read_receipt ] @@ -285,7 +315,7 @@ class EventPushActionsStore(SQLBaseStore): def get_after_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight, e.received_ts" " FROM (" " SELECT room_id," " MAX(topological_ordering) as topological_ordering," @@ -327,7 +357,7 @@ class EventPushActionsStore(SQLBaseStore): def get_no_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight, e.received_ts" " FROM event_push_actions AS ep" " INNER JOIN events AS e USING (room_id, event_id)" " WHERE" @@ -357,8 +387,8 @@ class EventPushActionsStore(SQLBaseStore): "event_id": row[0], "room_id": row[1], "stream_ordering": row[2], - "actions": json.loads(row[3]), - "received_ts": row[4], + "actions": _deserialize_action(row[3], row[4]), + "received_ts": row[5], } for row in after_read_receipt + no_read_receipt ] @@ -371,6 +401,290 @@ class EventPushActionsStore(SQLBaseStore): # Now return the first `limit` defer.returnValue(notifs[:limit]) + def add_push_actions_to_staging(self, event_id, user_id_actions): + """Add the push actions for the event to the push action staging area. + + Args: + event_id (str) + user_id_actions (dict[str, list[dict|str])]): A dictionary mapping + user_id to list of push actions, where an action can either be + a string or dict. + + Returns: + Deferred + """ + + if not user_id_actions: + return + + # This is a helper function for generating the necessary tuple that + # can be used to inert into the `event_push_actions_staging` table. + def _gen_entry(user_id, actions): + is_highlight = 1 if _action_has_highlight(actions) else 0 + return ( + event_id, # event_id column + user_id, # user_id column + _serialize_action(actions, is_highlight), # actions column + 1, # notif column + is_highlight, # highlight column + ) + + def _add_push_actions_to_staging_txn(txn): + # We don't use _simple_insert_many here to avoid the overhead + # of generating lists of dicts. + + sql = """ + INSERT INTO event_push_actions_staging + (event_id, user_id, actions, notif, highlight) + VALUES (?, ?, ?, ?, ?) + """ + + txn.executemany(sql, ( + _gen_entry(user_id, actions) + for user_id, actions in user_id_actions.iteritems() + )) + + return self.runInteraction( + "add_push_actions_to_staging", _add_push_actions_to_staging_txn + ) + + @defer.inlineCallbacks + def remove_push_actions_from_staging(self, event_id): + """Called if we failed to persist the event to ensure that stale push + actions don't build up in the DB + + Args: + event_id (str) + """ + + try: + res = yield self._simple_delete( + table="event_push_actions_staging", + keyvalues={ + "event_id": event_id, + }, + desc="remove_push_actions_from_staging", + ) + defer.returnValue(res) + except Exception: + # this method is called from an exception handler, so propagating + # another exception here really isn't helpful - there's nothing + # the caller can do about it. Just log the exception and move on. + logger.exception( + "Error removing push actions after event persistence failure", + ) + + @defer.inlineCallbacks + def _find_stream_orderings_for_times(self): + yield self.runInteraction( + "_find_stream_orderings_for_times", + self._find_stream_orderings_for_times_txn + ) + + def _find_stream_orderings_for_times_txn(self, txn): + logger.info("Searching for stream ordering 1 month ago") + self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 month ago: it's %d", + self.stream_ordering_month_ago + ) + logger.info("Searching for stream ordering 1 day ago") + self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 day ago: it's %d", + self.stream_ordering_day_ago + ) + + def find_first_stream_ordering_after_ts(self, ts): + """Gets the stream ordering corresponding to a given timestamp. + + Specifically, finds the stream_ordering of the first event that was + received on or after the timestamp. This is done by a binary search on + the events table, since there is no index on received_ts, so is + relatively slow. + + Args: + ts (int): timestamp in millis + + Returns: + Deferred[int]: stream ordering of the first event received on/after + the timestamp + """ + return self.runInteraction( + "_find_first_stream_ordering_after_ts_txn", + self._find_first_stream_ordering_after_ts_txn, + ts, + ) + + @staticmethod + def _find_first_stream_ordering_after_ts_txn(txn, ts): + """ + Find the stream_ordering of the first event that was received on or + after a given timestamp. This is relatively slow as there is no index + on received_ts but we can then use this to delete push actions before + this. + + received_ts must necessarily be in the same order as stream_ordering + and stream_ordering is indexed, so we manually binary search using + stream_ordering + + Args: + txn (twisted.enterprise.adbapi.Transaction): + ts (int): timestamp to search for + + Returns: + int: stream ordering + """ + txn.execute("SELECT MAX(stream_ordering) FROM events") + max_stream_ordering = txn.fetchone()[0] + + if max_stream_ordering is None: + return 0 + + # We want the first stream_ordering in which received_ts is greater + # than or equal to ts. Call this point X. + # + # We maintain the invariants: + # + # range_start <= X <= range_end + # + range_start = 0 + range_end = max_stream_ordering + 1 + + # Given a stream_ordering, look up the timestamp at that + # stream_ordering. + # + # The array may be sparse (we may be missing some stream_orderings). + # We treat the gaps as the same as having the same value as the + # preceding entry, because we will pick the lowest stream_ordering + # which satisfies our requirement of received_ts >= ts. + # + # For example, if our array of events indexed by stream_ordering is + # [10, <none>, 20], we should treat this as being equivalent to + # [10, 10, 20]. + # + sql = ( + "SELECT received_ts FROM events" + " WHERE stream_ordering <= ?" + " ORDER BY stream_ordering DESC" + " LIMIT 1" + ) + + while range_end - range_start > 0: + middle = (range_end + range_start) // 2 + txn.execute(sql, (middle,)) + row = txn.fetchone() + if row is None: + # no rows with stream_ordering<=middle + range_start = middle + 1 + continue + + middle_ts = row[0] + if ts > middle_ts: + # we got a timestamp lower than the one we were looking for. + # definitely need to look higher: X > middle. + range_start = middle + 1 + else: + # we got a timestamp higher than (or the same as) the one we + # were looking for. We aren't yet sure about the point we + # looked up, but we can be sure that X <= middle. + range_end = middle + + return range_end + + +class EventPushActionsStore(EventPushActionsWorkerStore): + EPA_HIGHLIGHT_INDEX = "epa_highlight_index" + + def __init__(self, db_conn, hs): + super(EventPushActionsStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + self.EPA_HIGHLIGHT_INDEX, + index_name="event_push_actions_u_highlight", + table="event_push_actions", + columns=["user_id", "stream_ordering"], + ) + + self.register_background_index_update( + "event_push_actions_highlights_index", + index_name="event_push_actions_highlights_index", + table="event_push_actions", + columns=["user_id", "room_id", "topological_ordering", "stream_ordering"], + where_clause="highlight=1" + ) + + self._doing_notif_rotation = False + self._rotate_notif_loop = self._clock.looping_call( + self._rotate_notifs, 30 * 60 * 1000 + ) + + def _set_push_actions_for_event_and_users_txn(self, txn, events_and_contexts, + all_events_and_contexts): + """Handles moving push actions from staging table to main + event_push_actions table for all events in `events_and_contexts`. + + Also ensures that all events in `all_events_and_contexts` are removed + from the push action staging area. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + """ + + sql = """ + INSERT INTO event_push_actions ( + room_id, event_id, user_id, actions, stream_ordering, + topological_ordering, notif, highlight + ) + SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight + FROM event_push_actions_staging + WHERE event_id = ? + """ + + if events_and_contexts: + txn.executemany(sql, ( + ( + event.room_id, event.internal_metadata.stream_ordering, + event.depth, event.event_id, + ) + for event, _ in events_and_contexts + )) + + for event, _ in events_and_contexts: + user_ids = self._simple_select_onecol_txn( + txn, + table="event_push_actions_staging", + keyvalues={ + "event_id": event.event_id, + }, + retcol="user_id", + ) + + for uid in user_ids: + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (event.room_id, uid,) + ) + + # Now we delete the staging area for *all* events that were being + # persisted. + txn.executemany( + "DELETE FROM event_push_actions_staging WHERE event_id = ?", + ( + (event.event_id,) + for event, _ in all_events_and_contexts + ) + ) + @defer.inlineCallbacks def get_push_actions_for_user(self, user_id, before=None, limit=50, only_highlight=False): @@ -392,7 +706,7 @@ class EventPushActionsStore(SQLBaseStore): sql = ( "SELECT epa.event_id, epa.room_id," " epa.stream_ordering, epa.topological_ordering," - " epa.actions, epa.profile_tag, e.received_ts" + " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" " FROM event_push_actions epa, events e" " WHERE epa.event_id = e.event_id" " AND epa.user_id = ? %s" @@ -407,7 +721,7 @@ class EventPushActionsStore(SQLBaseStore): "get_push_actions_for_user", f ) for pa in push_actions: - pa["actions"] = json.loads(pa["actions"]) + pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) defer.returnValue(push_actions) @defer.inlineCallbacks @@ -448,7 +762,7 @@ class EventPushActionsStore(SQLBaseStore): ) def _remove_old_push_actions_before_txn(self, txn, room_id, user_id, - topological_ordering): + topological_ordering, stream_ordering): """ Purges old push actions for a user and room before a given topological_ordering. @@ -479,65 +793,140 @@ class EventPushActionsStore(SQLBaseStore): txn.execute( "DELETE FROM event_push_actions " " WHERE user_id = ? AND room_id = ? AND " - " topological_ordering < ?" + " topological_ordering <= ?" " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", (user_id, room_id, topological_ordering, self.stream_ordering_month_ago) ) + txn.execute(""" + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? + """, (room_id, user_id, stream_ordering)) + @defer.inlineCallbacks - def _find_stream_orderings_for_times(self): - yield self.runInteraction( - "_find_stream_orderings_for_times", - self._find_stream_orderings_for_times_txn - ) + def _rotate_notifs(self): + if self._doing_notif_rotation or self.stream_ordering_day_ago is None: + return + self._doing_notif_rotation = True - def _find_stream_orderings_for_times_txn(self, txn): - logger.info("Searching for stream ordering 1 month ago") - self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 - ) - logger.info( - "Found stream ordering 1 month ago: it's %d", - self.stream_ordering_month_ago + try: + while True: + logger.info("Rotating notifications") + + caught_up = yield self.runInteraction( + "_rotate_notifs", + self._rotate_notifs_txn + ) + if caught_up: + break + yield sleep(5) + finally: + self._doing_notif_rotation = False + + def _rotate_notifs_txn(self, txn): + """Archives older notifications into event_push_summary. Returns whether + the archiving process has caught up or not. + """ + + old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", ) - def _find_first_stream_ordering_after_ts_txn(self, txn, ts): - """ - Find the stream_ordering of the first event that was received after - a given timestamp. This is relatively slow as there is no index on - received_ts but we can then use this to delete push actions before - this. + # We don't to try and rotate millions of rows at once, so we cap the + # maximum stream ordering we'll rotate before. + txn.execute(""" + SELECT stream_ordering FROM event_push_actions + WHERE stream_ordering > ? + ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000 + """, (old_rotate_stream_ordering,)) + stream_row = txn.fetchone() + if stream_row: + offset_stream_ordering, = stream_row + rotate_to_stream_ordering = min( + self.stream_ordering_day_ago, offset_stream_ordering + ) + caught_up = offset_stream_ordering >= self.stream_ordering_day_ago + else: + rotate_to_stream_ordering = self.stream_ordering_day_ago + caught_up = True - received_ts must necessarily be in the same order as stream_ordering - and stream_ordering is indexed, so we manually binary search using - stream_ordering + logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) + + self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) + + # We have caught up iff we were limited by `stream_ordering_day_ago` + return caught_up + + def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): + old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + # Calculate the new counts that should be upserted into event_push_summary + sql = """ + SELECT user_id, room_id, + coalesce(old.notif_count, 0) + upd.notif_count, + upd.stream_ordering, + old.user_id + FROM ( + SELECT user_id, room_id, count(*) as notif_count, + max(stream_ordering) as stream_ordering + FROM event_push_actions + WHERE ? <= stream_ordering AND stream_ordering < ? + AND highlight = 0 + GROUP BY user_id, room_id + ) AS upd + LEFT JOIN event_push_summary AS old USING (user_id, room_id) """ - txn.execute("SELECT MAX(stream_ordering) FROM events") - max_stream_ordering = txn.fetchone()[0] - if max_stream_ordering is None: - return 0 + txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,)) + rows = txn.fetchall() + + logger.info("Rotating notifications, handling %d rows", len(rows)) + + # If the `old.user_id` above is NULL then we know there isn't already an + # entry in the table, so we simply insert it. Otherwise we update the + # existing table. + self._simple_insert_many_txn( + txn, + table="event_push_summary", + values=[ + { + "user_id": row[0], + "room_id": row[1], + "notif_count": row[2], + "stream_ordering": row[3], + } + for row in rows if row[4] is None + ] + ) - range_start = 0 - range_end = max_stream_ordering + txn.executemany( + """ + UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? + WHERE user_id = ? AND room_id = ? + """, + ((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None) + ) - sql = ( - "SELECT received_ts FROM events" - " WHERE stream_ordering > ?" - " ORDER BY stream_ordering" - " LIMIT 1" + txn.execute( + "DELETE FROM event_push_actions" + " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", + (old_rotate_stream_ordering, rotate_to_stream_ordering,) ) - while range_end - range_start > 1: - middle = int((range_end + range_start) / 2) - txn.execute(sql, (middle,)) - middle_ts = txn.fetchone()[0] - if ts > middle_ts: - range_start = middle - else: - range_end = middle + logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) - return range_end + txn.execute( + "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", + (rotate_to_stream_ordering,) + ) def _action_has_highlight(actions): diff --git a/synapse/storage/events.py b/synapse/storage/events.py index c88f689d3a..05cde96afc 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,59 +13,61 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from twisted.internet import defer, reactor +from collections import OrderedDict, deque, namedtuple +from functools import wraps +import itertools +import logging -from synapse.events import FrozenEvent, USE_FROZEN_DICTS -from synapse.events.utils import prune_event +import simplejson as json +from twisted.internet import defer +from synapse.storage.events_worker import EventsWorkerStore from synapse.util.async import ObservableDeferred +from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.logcontext import ( - preserve_fn, PreserveLoggingContext, preserve_context_over_deferred + PreserveLoggingContext, make_deferred_yieldable, ) from synapse.util.logutils import log_function from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError -from synapse.state import resolve_events -from synapse.util.caches.descriptors import cached - -from canonicaljson import encode_canonical_json -from collections import deque, namedtuple, OrderedDict -from functools import wraps - -import synapse +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.types import get_domain_from_id import synapse.metrics - -import logging -import math -import ujson as json +# these are only included to make the type annotations work +from synapse.events import EventBase # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) persist_event_counter = metrics.register_counter("persisted_events") +event_counter = metrics.register_counter( + "persisted_events_sep", labels=["type", "origin_type", "origin_entity"] +) - -def encode_json(json_object): - if USE_FROZEN_DICTS: - # ujson doesn't like frozen_dicts - return encode_canonical_json(json_object) - else: - return json.dumps(json_object, ensure_ascii=False) +# The number of times we are recalculating the current state +state_delta_counter = metrics.register_counter( + "state_delta", +) +# The number of times we are recalculating state when there is only a +# single forward extremity +state_delta_single_event_counter = metrics.register_counter( + "state_delta_single_event", +) +# The number of times we are reculating state when we could have resonably +# calculated the delta when we calculated the state for an event we were +# persisting. +state_delta_reuse_delta_counter = metrics.register_counter( + "state_delta_reuse_delta", +) -# These values are used in the `enqueus_event` and `_do_fetch` methods to -# control how we batch/bulk fetch events from the database. -# The values are plucked out of thing air to make initial sync run faster -# on jki.re -# TODO: Make these configurable. -EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events -EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events -EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events +def encode_json(json_object): + return frozendict_json_encoder.encode(json_object) class _EventPeristenceQueue(object): @@ -82,15 +85,30 @@ class _EventPeristenceQueue(object): def add_to_queue(self, room_id, events_and_contexts, backfilled): """Add events to the queue, with the given persist_event options. + + NB: due to the normal usage pattern of this method, it does *not* + follow the synapse logcontext rules, and leaves the logcontext in + place whether or not the returned deferred is ready. + + Args: + room_id (str): + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + + Returns: + defer.Deferred: a deferred which will resolve once the events are + persisted. Runs its callbacks *without* a logcontext. """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: + # if the last item in the queue has the same `backfilled` setting, + # we can just add these new events to that item. end_item = queue[-1] if end_item.backfilled == backfilled: end_item.events_and_contexts.extend(events_and_contexts) return end_item.deferred.observe() - deferred = ObservableDeferred(defer.Deferred()) + deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True) queue.append(self._EventPersistQueueItem( events_and_contexts=events_and_contexts, @@ -103,11 +121,11 @@ class _EventPeristenceQueue(object): def handle_queue(self, room_id, per_item_callback): """Attempts to handle the queue for a room if not already being handled. - The given callback will be invoked with for each item in the queue,1 + The given callback will be invoked with for each item in the queue, of type _EventPersistQueueItem. The per_item_callback will continuously be called with new items, unless the queue becomnes empty. The return value of the function will be given to the deferreds waiting on the item, - exceptions will be passed to the deferres as well. + exceptions will be passed to the deferreds as well. This function should therefore be called whenever anything is added to the queue. @@ -126,18 +144,25 @@ class _EventPeristenceQueue(object): try: queue = self._get_drainining_queue(room_id) for item in queue: + # handle_queue_loop runs in the sentinel logcontext, so + # there is no need to preserve_fn when running the + # callbacks on the deferred. try: ret = yield per_item_callback(item) item.deferred.callback(ret) - except Exception as e: - item.deferred.errback(e) + except Exception: + item.deferred.errback() finally: queue = self._event_persist_queues.pop(room_id, None) if queue: self._event_persist_queues[room_id] = queue self._currently_persisting_rooms.discard(room_id) - preserve_fn(handle_queue_loop)() + # set handle_queue_loop off on the background. We don't want to + # attribute work done in it to the current request, so we drop the + # logcontext altogether. + with PreserveLoggingContext(): + handle_queue_loop() def _get_drainining_queue(self, room_id): queue = self._event_persist_queues.setdefault(room_id, deque()) @@ -173,13 +198,12 @@ def _retry_on_integrity_error(func): return f -class EventsStore(SQLBaseStore): +class EventsStore(EventsWorkerStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" - def __init__(self, hs): - super(EventsStore, self).__init__(hs) - self._clock = hs.get_clock() + def __init__(self, db_conn, hs): + super(EventsStore, self).__init__(db_conn, hs) self.register_background_update_handler( self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts ) @@ -196,8 +220,22 @@ class EventsStore(SQLBaseStore): where_clause="contains_url = true AND outlier = false", ) + # an event_id index on event_search is useful for the purge_history + # api. Plus it means we get to enforce some integrity with a UNIQUE + # clause + self.register_background_index_update( + "event_search_event_id_idx", + index_name="event_search_event_id_idx", + table="event_search", + columns=["event_id"], + unique=True, + psql_only=True, + ) + self._event_persist_queue = _EventPeristenceQueue() + self._state_resolution_handler = hs.get_state_resolution_handler() + def persist_events(self, events_and_contexts, backfilled=False): """ Write events to the database @@ -210,23 +248,34 @@ class EventsStore(SQLBaseStore): partitioned.setdefault(event.room_id, []).append((event, ctx)) deferreds = [] - for room_id, evs_ctxs in partitioned.items(): - d = preserve_fn(self._event_persist_queue.add_to_queue)( + for room_id, evs_ctxs in partitioned.iteritems(): + d = self._event_persist_queue.add_to_queue( room_id, evs_ctxs, backfilled=backfilled, ) deferreds.append(d) - for room_id in partitioned.keys(): + for room_id in partitioned: self._maybe_start_persisting(room_id) - return preserve_context_over_deferred( + return make_deferred_yieldable( defer.gatherResults(deferreds, consumeErrors=True) ) @defer.inlineCallbacks @log_function def persist_event(self, event, context, backfilled=False): + """ + + Args: + event (EventBase): + context (EventContext): + backfilled (bool): + + Returns: + Deferred: resolves to (int, int): the stream ordering of ``event``, + and the stream ordering of the latest persisted event + """ deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled, @@ -234,7 +283,7 @@ class EventsStore(SQLBaseStore): self._maybe_start_persisting(event.room_id) - yield preserve_context_over_deferred(deferred) + yield make_deferred_yieldable(deferred) max_persisted_id = yield self._stream_id_gen.get_current_token() defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id)) @@ -242,10 +291,11 @@ class EventsStore(SQLBaseStore): def _maybe_start_persisting(self, room_id): @defer.inlineCallbacks def persisting_queue(item): - yield self._persist_events( - item.events_and_contexts, - backfilled=item.backfilled, - ) + with Measure(self._clock, "persist_events"): + yield self._persist_events( + item.events_and_contexts, + backfilled=item.backfilled, + ) self._event_persist_queue.handle_queue(room_id, persisting_queue) @@ -253,6 +303,16 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks def _persist_events(self, events_and_contexts, backfilled=False, delete_existing=False): + """Persist events to db + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + delete_existing (bool): + + Returns: + Deferred: resolves when the events have been persisted + """ if not events_and_contexts: return @@ -282,8 +342,20 @@ class EventsStore(SQLBaseStore): # NB: Assumes that we are only persisting events for one room # at a time. + + # map room_id->list[event_ids] giving the new forward + # extremities in each room new_forward_extremeties = {} + + # map room_id->(type,state_key)->event_id tracking the full + # state in each room after adding these events current_state_for_room = {} + + # map room_id->(to_delete, to_insert) where each entry is + # a map (type,key)->event_id giving the state delta in each + # room + state_delta_for_room = {} + if not backfilled: with Measure(self._clock, "_calculate_state_and_extrem"): # Work out the new "current state" for each room. @@ -295,7 +367,7 @@ class EventsStore(SQLBaseStore): (event, context) ) - for room_id, ev_ctx_rm in events_by_room.items(): + for room_id, ev_ctx_rm in events_by_room.iteritems(): # Work out new extremities by recursively adding and removing # the new events. latest_event_ids = yield self.get_latest_event_ids_in_room( @@ -305,17 +377,64 @@ class EventsStore(SQLBaseStore): room_id, ev_ctx_rm, latest_event_ids ) - if new_latest_event_ids == set(latest_event_ids): + latest_event_ids = set(latest_event_ids) + if new_latest_event_ids == latest_event_ids: # No change in extremities, so no change in state continue new_forward_extremeties[room_id] = new_latest_event_ids - state = yield self._calculate_state_delta( - room_id, ev_ctx_rm, new_latest_event_ids + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 + ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_events) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm + ) + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: + continue + + state_delta_counter.inc() + if len(new_latest_event_ids) == 1: + state_delta_single_event_counter.inc() + + # This is a fairly handwavey check to see if we could + # have guessed what the delta would have been when + # processing one of these events. + # What we're interested in is if the latest extremities + # were the same when we created the event as they are + # now. When this server creates a new event (as opposed + # to receiving it over federation) it will use the + # forward extremities as the prev_events, so we can + # guess this by looking at the prev_events and checking + # if they match the current forward extremities. + for ev, _ in ev_ctx_rm: + prev_event_ids = set(e for e, _ in ev.prev_events) + if latest_event_ids == prev_event_ids: + state_delta_reuse_delta_counter.inc() + break + + logger.info( + "Calculating state delta for room %s", room_id, ) - if state: - current_state_for_room[room_id] = state + current_state = yield self._get_new_state_after_events( + room_id, + ev_ctx_rm, + latest_event_ids, + new_latest_event_ids, + ) + if current_state is not None: + current_state_for_room[room_id] = current_state + delta = yield self._calculate_state_delta( + room_id, current_state, + ) + if delta is not None: + state_delta_for_room[room_id] = delta yield self.runInteraction( "persist_events", @@ -323,10 +442,35 @@ class EventsStore(SQLBaseStore): events_and_contexts=chunk, backfilled=backfilled, delete_existing=delete_existing, - current_state_for_room=current_state_for_room, + state_delta_for_room=state_delta_for_room, new_forward_extremeties=new_forward_extremeties, ) persist_event_counter.inc_by(len(chunk)) + synapse.metrics.event_persisted_position.set( + chunk[-1][0].internal_metadata.stream_ordering, + ) + for event, context in chunk: + if context.app_service: + origin_type = "local" + origin_entity = context.app_service.id + elif self.hs.is_mine_id(event.sender): + origin_type = "local" + origin_entity = "*client*" + else: + origin_type = "remote" + origin_entity = get_domain_from_id(event.sender) + + event_counter.inc(event.type, origin_type, origin_entity) + + for room_id, new_state in current_state_for_room.iteritems(): + self.get_current_state_ids.prefill( + (room_id, ), new_state + ) + + for room_id, latest_event_ids in new_forward_extremeties.iteritems(): + self.get_latest_event_ids_in_room.prefill( + (room_id,), list(latest_event_ids) + ) @defer.inlineCallbacks def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): @@ -370,71 +514,137 @@ class EventsStore(SQLBaseStore): defer.returnValue(new_latest_event_ids) @defer.inlineCallbacks - def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids): - """Calculate the new state deltas for a room. + def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids, + new_latest_event_ids): + """Calculate the current state dict after adding some new events to + a room - Assumes that we are only persisting events for one room at a time. + Args: + room_id (str): + room to which the events are being added. Used for logging etc + + events_context (list[(EventBase, EventContext)]): + events and contexts which are being added to the room + + old_latest_event_ids (iterable[str]): + the old forward extremities for the room. + + new_latest_event_ids (iterable[str]): + the new forward extremities for the room. Returns: - 2-tuple (to_delete, to_insert) where both are state dicts, i.e. - (type, state_key) -> event_id. `to_delete` are the entries to - first be deleted from current_state_events, `to_insert` are entries - to insert. - May return None if there are no changes to be applied. + Deferred[dict[(str,str), str]|None]: + None if there are no changes to the room state, or + a dict of (type, state_key) -> event_id]. """ - # Now we need to work out the different state sets for - # each state extremities - state_sets = [] - missing_event_ids = [] - was_updated = False + + if not new_latest_event_ids: + return + + # map from state_group to ((type, key) -> event_id) state map + state_groups_map = {} + for ev, ctx in events_context: + if ctx.state_group is None: + # I don't think this can happen, but let's double-check + raise Exception( + "Context for new extremity event %s has no state " + "group" % (ev.event_id, ), + ) + + if ctx.state_group in state_groups_map: + continue + + state_groups_map[ctx.state_group] = ctx.current_state_ids + + # We need to map the event_ids to their state groups. First, let's + # check if the event is one we're persisting, in which case we can + # pull the state group from its context. + # Otherwise we need to pull the state group from the database. + + # Set of events we need to fetch groups for. (We know none of the old + # extremities are going to be in events_context). + missing_event_ids = set(old_latest_event_ids) + + event_id_to_state_group = {} for event_id in new_latest_event_ids: - # First search in the list of new events we're adding, - # and then use the current state from that + # First search in the list of new events we're adding. for ev, ctx in events_context: if event_id == ev.event_id: - if ctx.current_state_ids is None: - raise Exception("Unknown current state") - state_sets.append(ctx.current_state_ids) - if ctx.delta_ids or hasattr(ev, "state_key"): - was_updated = True + event_id_to_state_group[event_id] = ctx.state_group break else: # If we couldn't find it, then we'll need to pull # the state from the database - was_updated = True - missing_event_ids.append(event_id) + missing_event_ids.add(event_id) if missing_event_ids: - # Now pull out the state for any missing events from DB + # Now pull out the state groups for any missing events from DB event_to_groups = yield self._get_state_group_for_events( missing_event_ids, ) + event_id_to_state_group.update(event_to_groups) - groups = set(event_to_groups.values()) - group_to_state = yield self._get_state_for_groups(groups) + # State groups of old_latest_event_ids + old_state_groups = set( + event_id_to_state_group[evid] for evid in old_latest_event_ids + ) - state_sets.extend(group_to_state.values()) + # State groups of new_latest_event_ids + new_state_groups = set( + event_id_to_state_group[evid] for evid in new_latest_event_ids + ) - if not new_latest_event_ids: - current_state = {} - elif was_updated: - current_state = yield resolve_events( - state_sets, - state_map_factory=lambda ev_ids: self.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ), - ) - else: + # If they old and new groups are the same then we don't need to do + # anything. + if old_state_groups == new_state_groups: return - existing_state_rows = yield self._simple_select_list( - table="current_state_events", - keyvalues={"room_id": room_id}, - retcols=["event_id", "type", "state_key"], - desc="_calculate_state_delta", + # Now that we have calculated new_state_groups we need to get + # their state IDs so we can resolve to a single state set. + missing_state = new_state_groups - set(state_groups_map) + if missing_state: + group_to_state = yield self._get_state_for_groups(missing_state) + state_groups_map.update(group_to_state) + + if len(new_state_groups) == 1: + # If there is only one state group, then we know what the current + # state is. + defer.returnValue(state_groups_map[new_state_groups.pop()]) + + # Ok, we need to defer to the state handler to resolve our state sets. + + def get_events(ev_ids): + return self.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ) + + state_groups = { + sg: state_groups_map[sg] for sg in new_state_groups + } + + events_map = {ev.event_id: ev for ev, _ in events_context} + logger.debug("calling resolve_state_groups from preserve_events") + res = yield self._state_resolution_handler.resolve_state_groups( + room_id, state_groups, events_map, get_events ) - existing_events = set(row["event_id"] for row in existing_state_rows) + defer.returnValue(res.state) + + @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, current_state): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + 2-tuple (to_delete, to_insert) where both are state dicts, + i.e. (type, state_key) -> event_id. `to_delete` are the entries to + first be deleted from current_state_events, `to_insert` are entries + to insert. + """ + existing_state = yield self.get_current_state_ids(room_id) + + existing_events = set(existing_state.itervalues()) new_events = set(ev_id for ev_id in current_state.itervalues()) changed_events = existing_events ^ new_events @@ -442,9 +652,8 @@ class EventsStore(SQLBaseStore): return to_delete = { - (row["type"], row["state_key"]): row["event_id"] - for row in existing_state_rows - if row["event_id"] in changed_events + key: ev_id for key, ev_id in existing_state.iteritems() + if ev_id in changed_events } events_to_insert = (new_events - existing_events) to_insert = { @@ -454,77 +663,104 @@ class EventsStore(SQLBaseStore): defer.returnValue((to_delete, to_insert)) - @defer.inlineCallbacks - def get_event(self, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False, - allow_none=False): - """Get an event from the database by event_id. + @log_function + def _persist_events_txn(self, txn, events_and_contexts, backfilled, + delete_existing=False, state_delta_for_room={}, + new_forward_extremeties={}): + """Insert some number of room events into the necessary database tables. + + Rejected events are only inserted into the events table, the events_json table, + and the rejections table. Things reading from those table will need to check + whether the event was rejected. Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if - False throw an exception. + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): + events to persist + backfilled (bool): True if the events were backfilled + delete_existing (bool): True to purge existing table rows for the + events from the database. This is useful when retrying due to + IntegrityError. + state_delta_for_room (dict[str, (list[str], list[str])]): + The current-state delta for each room. For each room, a tuple + (to_delete, to_insert), being a list of event ids to be removed + from the current state, and a list of event ids to be added to + the current state. + new_forward_extremeties (dict[str, list[str]]): + The new forward extremities for each room. For each room, a + list of the event ids which are the forward extremities. - Returns: - Deferred : A FrozenEvent. """ - events = yield self._get_events( - [event_id], - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, + all_events_and_contexts = events_and_contexts + + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering + + self._update_current_state_txn(txn, state_delta_for_room, max_stream_order) + + self._update_forward_extremities_txn( + txn, + new_forward_extremities=new_forward_extremeties, + max_stream_order=max_stream_order, ) - if not events and not allow_none: - raise SynapseError(404, "Could not find event %s" % (event_id,)) + # Ensure that we don't have the same event twice. + events_and_contexts = self._filter_events_and_contexts_for_duplicates( + events_and_contexts, + ) - defer.returnValue(events[0] if events else None) + self._update_room_depths_txn( + txn, + events_and_contexts=events_and_contexts, + backfilled=backfilled, + ) - @defer.inlineCallbacks - def get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - """Get events from the database + # _update_outliers_txn filters out any events which have already been + # persisted, and returns the filtered list. + events_and_contexts = self._update_outliers_txn( + txn, + events_and_contexts=events_and_contexts, + ) - Args: - event_ids (list): The event_ids of the events to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. + # From this point onwards the events are only events that we haven't + # seen before. - Returns: - Deferred : Dict from event_id to event. - """ - events = yield self._get_events( - event_ids, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, + if delete_existing: + # For paranoia reasons, we go and delete all the existing entries + # for these events so we can reinsert them. + # This gets around any problems with some tables already having + # entries. + self._delete_existing_rows_txn( + txn, + events_and_contexts=events_and_contexts, + ) + + self._store_event_txn( + txn, + events_and_contexts=events_and_contexts, ) - defer.returnValue({e.event_id: e for e in events}) + # Insert into event_to_state_groups. + self._store_event_state_mappings_txn(txn, events_and_contexts) - @log_function - def _persist_events_txn(self, txn, events_and_contexts, backfilled, - delete_existing=False, current_state_for_room={}, - new_forward_extremeties={}): - """Insert some number of room events into the necessary database tables. + # _store_rejected_events_txn filters out any events which were + # rejected, and returns the filtered list. + events_and_contexts = self._store_rejected_events_txn( + txn, + events_and_contexts=events_and_contexts, + ) - Rejected events are only inserted into the events table, the events_json table, - and the rejections table. Things reading from those table will need to check - whether the event was rejected. + # From this point onwards the events are only ones that weren't + # rejected. - If delete_existing is True then existing events will be purged from the - database before insertion. This is useful when retrying due to IntegrityError. - """ - max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - for room_id, current_state_tuple in current_state_for_room.iteritems(): + self._update_metadata_tables_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + backfilled=backfilled, + ) + + def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order): + for room_id, current_state_tuple in state_delta_by_room.iteritems(): to_delete, to_insert = current_state_tuple txn.executemany( "DELETE FROM current_state_events WHERE event_id = ?", @@ -545,6 +781,29 @@ class EventsStore(SQLBaseStore): ], ) + state_deltas = {key: None for key in to_delete} + state_deltas.update(to_insert) + + self._simple_insert_many_txn( + txn, + table="current_state_delta_stream", + values=[ + { + "stream_id": max_stream_order, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": ev_id, + "prev_event_id": to_delete.get(key, None), + } + for key, ev_id in state_deltas.iteritems() + ] + ) + + self._curr_state_delta_stream_cache.entity_has_changed( + room_id, max_stream_order, + ) + # Invalidate the various caches # Figure out the changes of membership to invalidate the @@ -553,24 +812,34 @@ class EventsStore(SQLBaseStore): # and which we have added, then we invlidate the caches for all # those users. members_changed = set( - state_key for ev_type, state_key in to_delete.iterkeys() - if ev_type == EventTypes.Member - ) - members_changed.update( - state_key for ev_type, state_key in to_insert.iterkeys() + state_key for ev_type, state_key in state_deltas if ev_type == EventTypes.Member ) for member in members_changed: self._invalidate_cache_and_stream( - txn, self.get_rooms_for_user, (member,) + txn, self.get_rooms_for_user_with_stream_ordering, (member,) + ) + + for host in set(get_domain_from_id(u) for u in members_changed): + self._invalidate_cache_and_stream( + txn, self.is_host_joined, (room_id, host) + ) + self._invalidate_cache_and_stream( + txn, self.was_host_joined, (room_id, host) ) self._invalidate_cache_and_stream( txn, self.get_users_in_room, (room_id,) ) - for room_id, new_extrem in new_forward_extremeties.items(): + self._invalidate_cache_and_stream( + txn, self.get_current_state_ids, (room_id,) + ) + + def _update_forward_extremities_txn(self, txn, new_forward_extremities, + max_stream_order): + for room_id, new_extrem in new_forward_extremities.iteritems(): self._simple_delete_txn( txn, table="event_forward_extremities", @@ -588,7 +857,7 @@ class EventsStore(SQLBaseStore): "event_id": ev_id, "room_id": room_id, } - for room_id, new_extrem in new_forward_extremeties.items() + for room_id, new_extrem in new_forward_extremities.iteritems() for ev_id in new_extrem ], ) @@ -605,13 +874,22 @@ class EventsStore(SQLBaseStore): "event_id": event_id, "stream_ordering": max_stream_order, } - for room_id, new_extrem in new_forward_extremeties.items() + for room_id, new_extrem in new_forward_extremities.iteritems() for event_id in new_extrem ] ) - # Ensure that we don't have the same event twice. - # Pick the earliest non-outlier if there is one, else the earliest one. + @classmethod + def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): + """Ensure that we don't have the same event twice. + + Pick the earliest non-outlier if there is one, else the earliest one. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + Returns: + list[(EventBase, EventContext)]: filtered list + """ new_events_and_contexts = OrderedDict() for event, context in events_and_contexts: prev_event_context = new_events_and_contexts.get(event.event_id) @@ -624,9 +902,17 @@ class EventsStore(SQLBaseStore): new_events_and_contexts[event.event_id] = (event, context) else: new_events_and_contexts[event.event_id] = (event, context) + return new_events_and_contexts.values() - events_and_contexts = new_events_and_contexts.values() + def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): + """Update min_depth for each room + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + backfilled (bool): True if the events were backfilled + """ depth_updates = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids @@ -642,9 +928,24 @@ class EventsStore(SQLBaseStore): event.depth, depth_updates.get(event.room_id, event.depth) ) - for room_id, depth in depth_updates.items(): + for room_id, depth in depth_updates.iteritems(): self._update_min_depth_for_room_txn(txn, room_id, depth) + def _update_outliers_txn(self, txn, events_and_contexts): + """Update any outliers with new event info. + + This turns outliers into ex-outliers (unless the new event was + rejected). + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without events which + are already in the events table. + """ txn.execute( "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % ( ",".join(["?"] * len(events_and_contexts)), @@ -654,34 +955,30 @@ class EventsStore(SQLBaseStore): have_persisted = { event_id: outlier - for event_id, outlier in txn.fetchall() + for event_id, outlier in txn } to_remove = set() for event, context in events_and_contexts: - if context.rejected: - # If the event is rejected then we don't care if the event - # was an outlier or not. - if event.event_id in have_persisted: - # If we have already seen the event then ignore it. - to_remove.add(event) - continue - if event.event_id not in have_persisted: continue to_remove.add(event) + if context.rejected: + # If the event is rejected then we don't care if the event + # was an outlier or not. + continue + outlier_persisted = have_persisted[event.event_id] if not event.internal_metadata.is_outlier() and outlier_persisted: # We received a copy of an event that we had already stored as # an outlier in the database. We now have some state at that # so we need to update the state_groups table with that state. - # insert into the state_group, state_groups_state and - # event_to_state_groups tables. + # insert into event_to_state_groups. try: - self._store_mult_state_groups_txn(txn, ((event, context),)) + self._store_event_state_mappings_txn(txn, ((event, context),)) except Exception: logger.exception("") raise @@ -726,37 +1023,19 @@ class EventsStore(SQLBaseStore): # event isn't an outlier any more. self._update_backward_extremeties(txn, [event]) - events_and_contexts = [ + return [ ec for ec in events_and_contexts if ec[0] not in to_remove ] + @classmethod + def _delete_existing_rows_txn(cls, txn, events_and_contexts): if not events_and_contexts: - # Make sure we don't pass an empty list to functions that expect to - # be storing at least one element. + # nothing to do here return - # From this point onwards the events are only events that we haven't - # seen before. - - def event_dict(event): - return { - k: v - for k, v in event.get_dict().items() - if k not in [ - "redacted", - "redacted_because", - ] - } - - if delete_existing: - # For paranoia reasons, we go and delete all the existing entries - # for these events so we can reinsert them. - # This gets around any problems with some tables already having - # entries. - - logger.info("Deleting existing") + logger.info("Deleting existing") - for table in ( + for table in ( "events", "event_auth", "event_json", @@ -779,11 +1058,30 @@ class EventsStore(SQLBaseStore): "redactions", "room_memberships", "topics" - ): - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts] - ) + ): + txn.executemany( + "DELETE FROM %s WHERE event_id = ?" % (table,), + [(ev.event_id,) for ev, _ in events_and_contexts] + ) + + def _store_event_txn(self, txn, events_and_contexts): + """Insert new events into the event and event_json tables + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + """ + + if not events_and_contexts: + # nothing to do here + return + + def event_dict(event): + d = event.get_dict() + d.pop("redacted", None) + d.pop("redacted_because", None) + return d self._simple_insert_many_txn( txn, @@ -827,6 +1125,19 @@ class EventsStore(SQLBaseStore): ], ) + def _store_rejected_events_txn(self, txn, events_and_contexts): + """Add rows to the 'rejections' table for received events which were + rejected + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without the rejected + events. + """ # Remove the rejected events from the list now that we've added them # to the events table and the events_json table. to_remove = set() @@ -838,24 +1149,37 @@ class EventsStore(SQLBaseStore): ) to_remove.add(event) - events_and_contexts = [ + return [ ec for ec in events_and_contexts if ec[0] not in to_remove ] + def _update_metadata_tables_txn(self, txn, events_and_contexts, + all_events_and_contexts, backfilled): + """Update all the miscellaneous tables for new events + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + all_events_and_contexts (list[(EventBase, EventContext)]): all + events that we were going to persist. This includes events + we've already persisted, etc, that wouldn't appear in + events_and_context. + backfilled (bool): True if the events were backfilled + """ + + # Insert all the push actions into the event_push_actions table. + self._set_push_actions_for_event_and_users_txn( + txn, + events_and_contexts=events_and_contexts, + all_events_and_contexts=all_events_and_contexts, + ) + if not events_and_contexts: - # Make sure we don't pass an empty list to functions that expect to - # be storing at least one element. + # nothing to do here return - # From this point onwards the events are only ones that weren't rejected. - for event, context in events_and_contexts: - # Insert all the push actions into the event_push_actions table. - if context.push_actions: - self._set_push_actions_for_event_and_users_txn( - txn, event, context.push_actions - ) - if event.type == EventTypes.Redaction and event.redacts is not None: # Remove the entries in the event_push_actions table for the # redacted event. @@ -874,13 +1198,10 @@ class EventsStore(SQLBaseStore): } for event, _ in events_and_contexts for auth_id, _ in event.auth_events + if event.is_state() ], ) - # Insert into the state_groups, state_groups_state, and - # event_to_state_groups tables. - self._store_mult_state_groups_txn(txn, events_and_contexts) - # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( @@ -967,13 +1288,6 @@ class EventsStore(SQLBaseStore): # Prefill the event cache self._add_to_cache(txn, events_and_contexts) - if backfilled: - # Backfilled events come before the current state so we don't need - # to update the current state table - return - - return - def _add_to_cache(self, txn, events_and_contexts): to_prefill = [] @@ -1037,13 +1351,49 @@ class EventsStore(SQLBaseStore): defer.returnValue(set(r["event_id"] for r in rows)) - def have_events(self, event_ids): + @defer.inlineCallbacks + def have_seen_events(self, event_ids): """Given a list of event ids, check if we have already processed them. + Args: + event_ids (iterable[str]): + Returns: - dict: Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps to - None. + Deferred[set[str]]: The events we have already seen. + """ + results = set() + + def have_seen_events_txn(txn, chunk): + sql = ( + "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" + % (",".join("?" * len(chunk)), ) + ) + txn.execute(sql, chunk) + for (event_id, ) in txn: + results.add(event_id) + + # break the input up into chunks of 100 + input_iterator = iter(event_ids) + for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), + []): + yield self.runInteraction( + "have_seen_events", + have_seen_events_txn, + chunk, + ) + defer.returnValue(results) + + def get_seen_events_with_rejections(self, event_ids): + """Given a list of event ids, check if we rejected them. + + Args: + event_ids (list[str]) + + Returns: + Deferred[dict[str, str|None): + Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps + to None. """ if not event_ids: return defer.succeed({}) @@ -1065,280 +1415,7 @@ class EventsStore(SQLBaseStore): return res - return self.runInteraction( - "have_events", f, - ) - - @defer.inlineCallbacks - def _get_events(self, event_ids, check_redacted=True, - get_prev_content=False, allow_rejected=False): - if not event_ids: - defer.returnValue([]) - - event_id_list = event_ids - event_ids = set(event_ids) - - event_entry_map = self._get_events_from_cache( - event_ids, - allow_rejected=allow_rejected, - ) - - missing_events_ids = [e for e in event_ids if e not in event_entry_map] - - if missing_events_ids: - missing_events = yield self._enqueue_events( - missing_events_ids, - check_redacted=check_redacted, - allow_rejected=allow_rejected, - ) - - event_entry_map.update(missing_events) - - events = [] - for event_id in event_id_list: - entry = event_entry_map.get(event_id, None) - if not entry: - continue - - if allow_rejected or not entry.event.rejected_reason: - if check_redacted and entry.redacted_event: - event = entry.redacted_event - else: - event = entry.event - - events.append(event) - - if get_prev_content: - if "replaces_state" in event.unsigned: - prev = yield self.get_event( - event.unsigned["replaces_state"], - get_prev_content=False, - allow_none=True, - ) - if prev: - event.unsigned = dict(event.unsigned) - event.unsigned["prev_content"] = prev.content - event.unsigned["prev_sender"] = prev.sender - - defer.returnValue(events) - - def _invalidate_get_event_cache(self, event_id): - self._get_event_cache.invalidate((event_id,)) - - def _get_events_from_cache(self, events, allow_rejected): - event_map = {} - - for event_id in events: - ret = self._get_event_cache.get((event_id,), None) - if not ret: - continue - - if allow_rejected or not ret.event.rejected_reason: - event_map[event_id] = ret - else: - event_map[event_id] = None - - return event_map - - def _do_fetch(self, conn): - """Takes a database connection and waits for requests for events from - the _event_fetch_list queue. - """ - event_list = [] - i = 0 - while True: - try: - with self._event_fetch_lock: - event_list = self._event_fetch_list - self._event_fetch_list = [] - - if not event_list: - single_threaded = self.database_engine.single_threaded - if single_threaded or i > EVENT_QUEUE_ITERATIONS: - self._event_fetch_ongoing -= 1 - return - else: - self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) - i += 1 - continue - i = 0 - - event_id_lists = zip(*event_list)[0] - event_ids = [ - item for sublist in event_id_lists for item in sublist - ] - - rows = self._new_transaction( - conn, "do_fetch", [], None, self._fetch_event_rows, event_ids - ) - - row_dict = { - r["event_id"]: r - for r in rows - } - - # We only want to resolve deferreds from the main thread - def fire(lst, res): - for ids, d in lst: - if not d.called: - try: - with PreserveLoggingContext(): - d.callback([ - res[i] - for i in ids - if i in res - ]) - except: - logger.exception("Failed to callback") - with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list, row_dict) - except Exception as e: - logger.exception("do_fetch") - - # We only want to resolve deferreds from the main thread - def fire(evs): - for _, d in evs: - if not d.called: - with PreserveLoggingContext(): - d.errback(e) - - if event_list: - with PreserveLoggingContext(): - reactor.callFromThread(fire, event_list) - - @defer.inlineCallbacks - def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): - """Fetches events from the database using the _event_fetch_list. This - allows batch and bulk fetching of events - it allows us to fetch events - without having to create a new transaction for each request for events. - """ - if not events: - defer.returnValue({}) - - events_d = defer.Deferred() - with self._event_fetch_lock: - self._event_fetch_list.append( - (events, events_d) - ) - - self._event_fetch_lock.notify() - - if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: - self._event_fetch_ongoing += 1 - should_start = True - else: - should_start = False - - if should_start: - with PreserveLoggingContext(): - self.runWithConnection( - self._do_fetch - ) - - logger.debug("Loading %d events", len(events)) - with PreserveLoggingContext(): - rows = yield events_d - logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) - - if not allow_rejected: - rows[:] = [r for r in rows if not r["rejects"]] - - res = yield preserve_context_over_deferred(defer.gatherResults( - [ - preserve_fn(self._get_event_from_row)( - row["internal_metadata"], row["json"], row["redacts"], - rejected_reason=row["rejects"], - ) - for row in rows - ], - consumeErrors=True - )) - - defer.returnValue({ - e.event.event_id: e - for e in res if e - }) - - def _fetch_event_rows(self, txn, events): - rows = [] - N = 200 - for i in range(1 + len(events) / N): - evs = events[i * N:(i + 1) * N] - if not evs: - break - - sql = ( - "SELECT " - " e.event_id as event_id, " - " e.internal_metadata," - " e.json," - " r.redacts as redacts," - " rej.event_id as rejects " - " FROM event_json as e" - " LEFT JOIN rejections as rej USING (event_id)" - " LEFT JOIN redactions as r ON e.event_id = r.redacts" - " WHERE e.event_id IN (%s)" - ) % (",".join(["?"] * len(evs)),) - - txn.execute(sql, evs) - rows.extend(self.cursor_to_dict(txn)) - - return rows - - @defer.inlineCallbacks - def _get_event_from_row(self, internal_metadata, js, redacted, - rejected_reason=None): - with Measure(self._clock, "_get_event_from_row"): - d = json.loads(js) - internal_metadata = json.loads(internal_metadata) - - if rejected_reason: - rejected_reason = yield self._simple_select_one_onecol( - table="rejections", - keyvalues={"event_id": rejected_reason}, - retcol="reason", - desc="_get_event_from_row_rejected_reason", - ) - - original_ev = FrozenEvent( - d, - internal_metadata_dict=internal_metadata, - rejected_reason=rejected_reason, - ) - - redacted_event = None - if redacted: - redacted_event = prune_event(original_ev) - - redaction_id = yield self._simple_select_one_onecol( - table="redactions", - keyvalues={"redacts": redacted_event.event_id}, - retcol="event_id", - desc="_get_event_from_row_redactions", - ) - - redacted_event.unsigned["redacted_by"] = redaction_id - # Get the redaction event. - - because = yield self.get_event( - redaction_id, - check_redacted=False, - allow_none=True, - ) - - if because: - # It's fine to do add the event directly, since get_pdu_json - # will serialise this field correctly - redacted_event.unsigned["redacted_because"] = because - - cache_entry = _EventCacheEntry( - event=original_ev, - redacted_event=redacted_event, - ) - - self._get_event_cache.prefill((original_ev.event_id,), cache_entry) - - defer.returnValue(cache_entry) + return self.runInteraction("get_rejection_reasons", f) @defer.inlineCallbacks def count_daily_messages(self): @@ -1349,66 +1426,52 @@ class EventsStore(SQLBaseStore): call to this function, it will return None. """ def _count_messages(txn): - now = self.hs.get_clock().time() - - txn.execute( - "SELECT reported_stream_token, reported_time FROM stats_reporting" - ) - last_reported = self.cursor_to_dict(txn) + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + count, = txn.fetchone() + return count - txn.execute( - "SELECT stream_ordering" - " FROM events" - " ORDER BY stream_ordering DESC" - " LIMIT 1" - ) - now_reporting = self.cursor_to_dict(txn) - if not now_reporting: - logger.info("Calculating daily messages skipped; no now_reporting") - return None - now_reporting = now_reporting[0]["stream_ordering"] - - txn.execute("DELETE FROM stats_reporting") - txn.execute( - "INSERT INTO stats_reporting" - " (reported_stream_token, reported_time)" - " VALUES (?, ?)", - (now_reporting, now,) - ) - - if not last_reported: - logger.info("Calculating daily messages skipped; no last_reported") - return None - - # Close enough to correct for our purposes. - yesterday = (now - 24 * 60 * 60) - since_yesterday_seconds = yesterday - last_reported[0]["reported_time"] - any_since_yesterday = math.fabs(since_yesterday_seconds) > 60 * 60 - if any_since_yesterday: - logger.info( - "Calculating daily messages skipped; since_yesterday_seconds: %d" % - (since_yesterday_seconds,) - ) - return None + ret = yield self.runInteraction("count_messages", _count_messages) + defer.returnValue(ret) - txn.execute( - "SELECT COUNT(*) as messages" - " FROM events NATURAL JOIN event_json" - " WHERE json like '%m.room.message%'" - " AND stream_ordering > ?" - " AND stream_ordering <= ?", - ( - last_reported[0]["reported_stream_token"], - now_reporting, - ) - ) - rows = self.cursor_to_dict(txn) - if not rows: - logger.info("Calculating daily messages skipped; messages count missing") - return None - return rows[0]["messages"] + @defer.inlineCallbacks + def count_daily_sent_messages(self): + def _count_messages(txn): + # This is good enough as if you have silly characters in your own + # hostname then thats your own fault. + like_clause = "%:" + self.hs.hostname + + sql = """ + SELECT COALESCE(COUNT(*), 0) FROM events + WHERE type = 'm.room.message' + AND sender LIKE ? + AND stream_ordering > ? + """ + + txn.execute(sql, (like_clause, self.stream_ordering_day_ago,)) + count, = txn.fetchone() + return count + + ret = yield self.runInteraction("count_daily_sent_messages", _count_messages) + defer.returnValue(ret) - ret = yield self.runInteraction("count_messages", _count_messages) + @defer.inlineCallbacks + def count_daily_active_rooms(self): + def _count(txn): + sql = """ + SELECT COALESCE(COUNT(DISTINCT room_id), 0) FROM events + WHERE type = 'm.room.message' + AND stream_ordering > ? + """ + txn.execute(sql, (self.stream_ordering_day_ago,)) + count, = txn.fetchone() + return count + + ret = yield self.runInteraction("count_daily_active_rooms", _count) defer.returnValue(ret) @defer.inlineCallbacks @@ -1569,6 +1632,94 @@ class EventsStore(SQLBaseStore): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() + def get_current_events_token(self): + """The current maximum token that events have reached""" + return self._stream_id_gen.get_current_token() + + def get_all_new_forward_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_forward_event_rows(txn): + sql = ( + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < event_stream_ordering" + " AND event_stream_ordering <= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (last_id, upper_bound)) + new_event_updates.extend(txn) + + return new_event_updates + return self.runInteraction( + "get_all_new_forward_event_rows", get_all_new_forward_event_rows + ) + + def get_all_new_backfill_event_rows(self, last_id, current_id, limit): + if last_id == current_id: + return defer.succeed([]) + + def get_all_new_backfill_event_rows(txn): + sql = ( + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering ASC" + " LIMIT ?" + ) + txn.execute(sql, (-last_id, -current_id, limit)) + new_event_updates = txn.fetchall() + + if len(new_event_updates) == limit: + upper_bound = new_event_updates[-1][0] + else: + upper_bound = current_id + + sql = ( + "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " INNER JOIN ex_outlier_stream USING (event_id)" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > event_stream_ordering" + " AND event_stream_ordering >= ?" + " ORDER BY event_stream_ordering DESC" + ) + txn.execute(sql, (-last_id, -upper_bound)) + new_event_updates.extend(txn.fetchall()) + + return new_event_updates + return self.runInteraction( + "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows + ) + @cached(num_args=5, max_entries=10) def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): @@ -1582,14 +1733,13 @@ class EventsStore(SQLBaseStore): def get_all_new_events_txn(txn): sql = ( - "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group" - " FROM events as e" - " JOIN event_json as ej" - " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" - " LEFT JOIN event_to_state_groups as eg" - " ON e.event_id = eg.event_id" - " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" " LIMIT ?" ) if have_forward_events: @@ -1615,15 +1765,13 @@ class EventsStore(SQLBaseStore): forward_ex_outliers = [] sql = ( - "SELECT -e.stream_ordering, ej.internal_metadata, ej.json," - " eg.state_group" - " FROM events as e" - " JOIN event_json as ej" - " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" - " LEFT JOIN event_to_state_groups as eg" - " ON e.event_id = eg.event_id" - " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" - " ORDER BY e.stream_ordering DESC" + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering DESC" " LIMIT ?" ) if have_backfill_events: @@ -1654,16 +1802,32 @@ class EventsStore(SQLBaseStore): ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) - def delete_old_state(self, room_id, topological_ordering): - return self.runInteraction( - "delete_old_state", - self._delete_old_state_txn, room_id, topological_ordering - ) + def purge_history( + self, room_id, topological_ordering, delete_local_events, + ): + """Deletes room history before a certain point + + Args: + room_id (str): - def _delete_old_state_txn(self, txn, room_id, topological_ordering): - """Deletes old room state + topological_ordering (int): + minimum topo ordering to preserve + + delete_local_events (bool): + if True, we will delete local events as well as remote ones + (instead of just marking them as outliers and deleting their + state groups). """ + return self.runInteraction( + "purge_history", + self._purge_history_txn, room_id, topological_ordering, + delete_local_events, + ) + + def _purge_history_txn( + self, txn, room_id, topological_ordering, delete_local_events, + ): # Tables that should be pruned: # event_auth # event_backward_extremities @@ -1684,6 +1848,30 @@ class EventsStore(SQLBaseStore): # state_groups # state_groups_state + # we will build a temporary table listing the events so that we don't + # have to keep shovelling the list back and forth across the + # connection. Annoyingly the python sqlite driver commits the + # transaction on CREATE, so let's do this first. + # + # furthermore, we might already have the table from a previous (failed) + # purge attempt, so let's drop the table first. + + txn.execute("DROP TABLE IF EXISTS events_to_purge") + + txn.execute( + "CREATE TEMPORARY TABLE events_to_purge (" + " event_id TEXT NOT NULL," + " should_delete BOOLEAN NOT NULL" + ")" + ) + + # create an index on should_delete because later we'll be looking for + # the should_delete / shouldn't_delete subsets + txn.execute( + "CREATE INDEX events_to_purge_should_delete" + " ON events_to_purge(should_delete)", + ) + # First ensure that we're not about to delete all the forward extremeties txn.execute( "SELECT e.event_id, e.depth FROM events as e " @@ -1704,29 +1892,49 @@ class EventsStore(SQLBaseStore): 400, "topological_ordering is greater than forward extremeties" ) + logger.info("[purge] looking for events to delete") + + should_delete_expr = "state_key IS NULL" + should_delete_params = () + if not delete_local_events: + should_delete_expr += " AND event_id NOT LIKE ?" + should_delete_params += ("%:" + self.hs.hostname, ) + + should_delete_params += (room_id, topological_ordering) + + txn.execute( + "INSERT INTO events_to_purge" + " SELECT event_id, %s" + " FROM events AS e LEFT JOIN state_events USING (event_id)" + " WHERE e.room_id = ? AND topological_ordering < ?" % ( + should_delete_expr, + ), + should_delete_params, + ) txn.execute( - "SELECT event_id, state_key FROM events" - " LEFT JOIN state_events USING (room_id, event_id)" - " WHERE room_id = ? AND topological_ordering < ?", - (room_id, topological_ordering,) + "SELECT event_id, should_delete FROM events_to_purge" ) event_rows = txn.fetchall() + logger.info( + "[purge] found %i events before cutoff, of which %i can be deleted", + len(event_rows), sum(1 for e in event_rows if e[1]), + ) - for event_id, state_key in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) + logger.info("[purge] Finding new backward extremities") # We calculate the new entries for the backward extremeties by finding # all events that point to events that are to be purged txn.execute( - "SELECT DISTINCT e.event_id FROM events as e" - " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id" - " INNER JOIN events as e2 ON e2.event_id = ed.event_id" - " WHERE e.room_id = ? AND e.topological_ordering < ?" - " AND e2.topological_ordering >= ?", - (room_id, topological_ordering, topological_ordering) + "SELECT DISTINCT e.event_id FROM events_to_purge AS e" + " INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id" + " INNER JOIN events AS e2 ON e2.event_id = ed.event_id" + " WHERE e2.topological_ordering >= ?", + (topological_ordering, ) ) new_backwards_extrems = txn.fetchall() + logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems) + txn.execute( "DELETE FROM event_backward_extremities WHERE room_id = ?", (room_id,) @@ -1741,30 +1949,36 @@ class EventsStore(SQLBaseStore): ] ) + logger.info("[purge] finding redundant state groups") + # Get all state groups that are only referenced by events that are # to be deleted. txn.execute( "SELECT state_group FROM event_to_state_groups" " INNER JOIN events USING (event_id)" " WHERE state_group IN (" - " SELECT DISTINCT state_group FROM events" + " SELECT DISTINCT state_group FROM events_to_purge" " INNER JOIN event_to_state_groups USING (event_id)" - " WHERE room_id = ? AND topological_ordering < ?" " )" " GROUP BY state_group HAVING MAX(topological_ordering) < ?", - (room_id, topological_ordering, topological_ordering) + (topological_ordering, ) ) state_rows = txn.fetchall() - state_groups_to_delete = [sg for sg, in state_rows] + logger.info("[purge] found %i redundant state groups", len(state_rows)) + + # make a set of the redundant state groups, so that we can look them up + # efficiently + state_groups_to_delete = set([sg for sg, in state_rows]) # Now we get all the state groups that rely on these state groups - new_state_edges = [] - chunks = [ - state_groups_to_delete[i:i + 100] - for i in xrange(0, len(state_groups_to_delete), 100) - ] - for chunk in chunks: + logger.info("[purge] finding state groups which depend on redundant" + " state groups") + remaining_state_groups = [] + for i in xrange(0, len(state_rows), 100): + chunk = [sg for sg, in state_rows[i:i + 100]] + # look for state groups whose prev_state_group is one we are about + # to delete rows = self._simple_select_many_txn( txn, table="state_group_edges", @@ -1773,21 +1987,28 @@ class EventsStore(SQLBaseStore): retcols=["state_group"], keyvalues={}, ) - new_state_edges.extend(row["state_group"] for row in rows) + remaining_state_groups.extend( + row["state_group"] for row in rows + + # exclude state groups we are about to delete: no point in + # updating them + if row["state_group"] not in state_groups_to_delete + ) - # Now we turn the state groups that reference to-be-deleted state groups - # to non delta versions. - for new_state_edge in new_state_edges: + # Now we turn the state groups that reference to-be-deleted state + # groups to non delta versions. + for sg in remaining_state_groups: + logger.info("[purge] de-delta-ing remaining state group %s", sg) curr_state = self._get_state_groups_from_groups_txn( - txn, [new_state_edge], types=None + txn, [sg], types=None ) - curr_state = curr_state[new_state_edge] + curr_state = curr_state[sg] self._simple_delete_txn( txn, table="state_groups_state", keyvalues={ - "state_group": new_state_edge, + "state_group": sg, } ) @@ -1795,7 +2016,7 @@ class EventsStore(SQLBaseStore): txn, table="state_group_edges", keyvalues={ - "state_group": new_state_edge, + "state_group": sg, } ) @@ -1804,16 +2025,17 @@ class EventsStore(SQLBaseStore): table="state_groups_state", values=[ { - "state_group": new_state_edge, + "state_group": sg, "room_id": room_id, "type": key[0], "state_key": key[1], "event_id": state_id, } - for key, state_id in curr_state.items() + for key, state_id in curr_state.iteritems() ], ) + logger.info("[purge] removing redundant state groups") txn.executemany( "DELETE FROM state_groups_state WHERE state_group = ?", state_rows @@ -1822,22 +2044,18 @@ class EventsStore(SQLBaseStore): "DELETE FROM state_groups WHERE id = ?", state_rows ) - # Delete all non-state - txn.executemany( - "DELETE FROM event_to_state_groups WHERE event_id = ?", - [(event_id,) for event_id, _ in event_rows] - ) + logger.info("[purge] removing events from event_to_state_groups") txn.execute( - "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", - (topological_ordering, room_id,) + "DELETE FROM event_to_state_groups " + "WHERE event_id IN (SELECT event_id from events_to_purge)" ) + for event_id, _ in event_rows: + txn.call_after(self._get_state_group_for_event.invalidate, ( + event_id, + )) # Delete all remote non-state events - to_delete = [ - (event_id,) for event_id, state_key in event_rows - if state_key is None and not self.hs.is_mine_id(event_id) - ] for table in ( "events", "event_json", @@ -1847,29 +2065,102 @@ class EventsStore(SQLBaseStore): "event_edge_hashes", "event_edges", "event_forward_extremities", - "event_push_actions", "event_reference_hashes", "event_search", "event_signatures", "rejections", ): - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - to_delete + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,), + ) + + # event_push_actions lacks an index on event_id, and has one on + # (room_id, event_id) instead. + for table in ( + "event_push_actions", + ): + logger.info("[purge] removing events from %s", table) + + txn.execute( + "DELETE FROM %s WHERE room_id = ? AND event_id IN (" + " SELECT event_id FROM events_to_purge WHERE should_delete" + ")" % (table,), + (room_id, ) ) - txn.executemany( - "DELETE FROM events WHERE event_id = ?", - to_delete - ) # Mark all state and own events as outliers - txn.executemany( + logger.info("[purge] marking remaining events as outliers") + txn.execute( "UPDATE events SET outlier = ?" - " WHERE event_id = ?", - [ - (True, event_id,) for event_id, state_key in event_rows - if state_key is not None or self.hs.is_mine_id(event_id) - ] + " WHERE event_id IN (" + " SELECT event_id FROM events_to_purge " + " WHERE NOT should_delete" + ")", + (True,), + ) + + # synapse tries to take out an exclusive lock on room_depth whenever it + # persists events (because upsert), and once we run this update, we + # will block that for the rest of our transaction. + # + # So, let's stick it at the end so that we don't block event + # persistence. + logger.info("[purge] updating room_depth") + txn.execute( + "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", + (topological_ordering, room_id,) + ) + + # finally, drop the temp table. this will commit the txn in sqlite, + # so make sure to keep this actually last. + txn.execute( + "DROP TABLE events_to_purge" + ) + + logger.info("[purge] done") + + @defer.inlineCallbacks + def is_event_after(self, event_id1, event_id2): + """Returns True if event_id1 is after event_id2 in the stream + """ + to_1, so_1 = yield self._get_event_ordering(event_id1) + to_2, so_2 = yield self._get_event_ordering(event_id2) + defer.returnValue((to_1, so_1) > (to_2, so_2)) + + @cachedInlineCallbacks(max_entries=5000) + def _get_event_ordering(self, event_id): + res = yield self._simple_select_one( + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + allow_none=True + ) + + if not res: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + defer.returnValue((int(res["topological_ordering"]), int(res["stream_ordering"]))) + + def get_max_current_state_delta_stream_id(self): + return self._stream_id_gen.get_current_token() + + def get_all_updated_current_state_deltas(self, from_token, to_token, limit): + def get_all_updated_current_state_deltas_txn(txn): + sql = """ + SELECT stream_id, room_id, type, state_key, event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit)) + return txn.fetchall() + return self.runInteraction( + "get_all_updated_current_state_deltas", + get_all_updated_current_state_deltas_txn, ) diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py new file mode 100644 index 0000000000..ba834854e1 --- /dev/null +++ b/synapse/storage/events_worker.py @@ -0,0 +1,416 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import SQLBaseStore + +from twisted.internet import defer, reactor + +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from synapse.util.logcontext import ( + PreserveLoggingContext, make_deferred_yieldable, run_in_background, +) +from synapse.util.metrics import Measure +from synapse.api.errors import SynapseError + +from collections import namedtuple + +import logging +import simplejson as json + +# these are only included to make the type annotations work +from synapse.events import EventBase # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 + +logger = logging.getLogger(__name__) + + +# These values are used in the `enqueus_event` and `_do_fetch` methods to +# control how we batch/bulk fetch events from the database. +# The values are plucked out of thing air to make initial sync run faster +# on jki.re +# TODO: Make these configurable. +EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events +EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events +EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events + + +_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) + + +class EventsWorkerStore(SQLBaseStore): + def get_received_ts(self, event_id): + """Get received_ts (when it was persisted) for the event. + + Raises an exception for unknown events. + + Args: + event_id (str) + + Returns: + Deferred[int|None]: Timestamp in milliseconds, or None for events + that were persisted before received_ts was implemented. + """ + return self._simple_select_one_onecol( + table="events", + keyvalues={ + "event_id": event_id, + }, + retcol="received_ts", + desc="get_received_ts", + ) + + @defer.inlineCallbacks + def get_event(self, event_id, check_redacted=True, + get_prev_content=False, allow_rejected=False, + allow_none=False): + """Get an event from the database by event_id. + + Args: + event_id (str): The event_id of the event to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + allow_none (bool): If True, return None if no event found, if + False throw an exception. + + Returns: + Deferred : A FrozenEvent. + """ + events = yield self._get_events( + [event_id], + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + if not events and not allow_none: + raise SynapseError(404, "Could not find event %s" % (event_id,)) + + defer.returnValue(events[0] if events else None) + + @defer.inlineCallbacks + def get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + """Get events from the database + + Args: + event_ids (list): The event_ids of the events to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred : Dict from event_id to event. + """ + events = yield self._get_events( + event_ids, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + allow_rejected=allow_rejected, + ) + + defer.returnValue({e.event_id: e for e in events}) + + @defer.inlineCallbacks + def _get_events(self, event_ids, check_redacted=True, + get_prev_content=False, allow_rejected=False): + if not event_ids: + defer.returnValue([]) + + event_id_list = event_ids + event_ids = set(event_ids) + + event_entry_map = self._get_events_from_cache( + event_ids, + allow_rejected=allow_rejected, + ) + + missing_events_ids = [e for e in event_ids if e not in event_entry_map] + + if missing_events_ids: + missing_events = yield self._enqueue_events( + missing_events_ids, + check_redacted=check_redacted, + allow_rejected=allow_rejected, + ) + + event_entry_map.update(missing_events) + + events = [] + for event_id in event_id_list: + entry = event_entry_map.get(event_id, None) + if not entry: + continue + + if allow_rejected or not entry.event.rejected_reason: + if check_redacted and entry.redacted_event: + event = entry.redacted_event + else: + event = entry.event + + events.append(event) + + if get_prev_content: + if "replaces_state" in event.unsigned: + prev = yield self.get_event( + event.unsigned["replaces_state"], + get_prev_content=False, + allow_none=True, + ) + if prev: + event.unsigned = dict(event.unsigned) + event.unsigned["prev_content"] = prev.content + event.unsigned["prev_sender"] = prev.sender + + defer.returnValue(events) + + def _invalidate_get_event_cache(self, event_id): + self._get_event_cache.invalidate((event_id,)) + + def _get_events_from_cache(self, events, allow_rejected, update_metrics=True): + """Fetch events from the caches + + Args: + events (list(str)): list of event_ids to fetch + allow_rejected (bool): Whether to teturn events that were rejected + update_metrics (bool): Whether to update the cache hit ratio metrics + + Returns: + dict of event_id -> _EventCacheEntry for each event_id in cache. If + allow_rejected is `False` then there will still be an entry but it + will be `None` + """ + event_map = {} + + for event_id in events: + ret = self._get_event_cache.get( + (event_id,), None, + update_metrics=update_metrics, + ) + if not ret: + continue + + if allow_rejected or not ret.event.rejected_reason: + event_map[event_id] = ret + else: + event_map[event_id] = None + + return event_map + + def _do_fetch(self, conn): + """Takes a database connection and waits for requests for events from + the _event_fetch_list queue. + """ + event_list = [] + i = 0 + while True: + try: + with self._event_fetch_lock: + event_list = self._event_fetch_list + self._event_fetch_list = [] + + if not event_list: + single_threaded = self.database_engine.single_threaded + if single_threaded or i > EVENT_QUEUE_ITERATIONS: + self._event_fetch_ongoing -= 1 + return + else: + self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S) + i += 1 + continue + i = 0 + + event_id_lists = zip(*event_list)[0] + event_ids = [ + item for sublist in event_id_lists for item in sublist + ] + + rows = self._new_transaction( + conn, "do_fetch", [], [], None, self._fetch_event_rows, event_ids + ) + + row_dict = { + r["event_id"]: r + for r in rows + } + + # We only want to resolve deferreds from the main thread + def fire(lst, res): + for ids, d in lst: + if not d.called: + try: + with PreserveLoggingContext(): + d.callback([ + res[i] + for i in ids + if i in res + ]) + except Exception: + logger.exception("Failed to callback") + with PreserveLoggingContext(): + reactor.callFromThread(fire, event_list, row_dict) + except Exception as e: + logger.exception("do_fetch") + + # We only want to resolve deferreds from the main thread + def fire(evs): + for _, d in evs: + if not d.called: + with PreserveLoggingContext(): + d.errback(e) + + if event_list: + with PreserveLoggingContext(): + reactor.callFromThread(fire, event_list) + + @defer.inlineCallbacks + def _enqueue_events(self, events, check_redacted=True, allow_rejected=False): + """Fetches events from the database using the _event_fetch_list. This + allows batch and bulk fetching of events - it allows us to fetch events + without having to create a new transaction for each request for events. + """ + if not events: + defer.returnValue({}) + + events_d = defer.Deferred() + with self._event_fetch_lock: + self._event_fetch_list.append( + (events, events_d) + ) + + self._event_fetch_lock.notify() + + if self._event_fetch_ongoing < EVENT_QUEUE_THREADS: + self._event_fetch_ongoing += 1 + should_start = True + else: + should_start = False + + if should_start: + with PreserveLoggingContext(): + self.runWithConnection( + self._do_fetch + ) + + logger.debug("Loading %d events", len(events)) + with PreserveLoggingContext(): + rows = yield events_d + logger.debug("Loaded %d events (%d rows)", len(events), len(rows)) + + if not allow_rejected: + rows[:] = [r for r in rows if not r["rejects"]] + + res = yield make_deferred_yieldable(defer.gatherResults( + [ + run_in_background( + self._get_event_from_row, + row["internal_metadata"], row["json"], row["redacts"], + rejected_reason=row["rejects"], + ) + for row in rows + ], + consumeErrors=True + )) + + defer.returnValue({ + e.event.event_id: e + for e in res if e + }) + + def _fetch_event_rows(self, txn, events): + rows = [] + N = 200 + for i in range(1 + len(events) / N): + evs = events[i * N:(i + 1) * N] + if not evs: + break + + sql = ( + "SELECT " + " e.event_id as event_id, " + " e.internal_metadata," + " e.json," + " r.redacts as redacts," + " rej.event_id as rejects " + " FROM event_json as e" + " LEFT JOIN rejections as rej USING (event_id)" + " LEFT JOIN redactions as r ON e.event_id = r.redacts" + " WHERE e.event_id IN (%s)" + ) % (",".join(["?"] * len(evs)),) + + txn.execute(sql, evs) + rows.extend(self.cursor_to_dict(txn)) + + return rows + + @defer.inlineCallbacks + def _get_event_from_row(self, internal_metadata, js, redacted, + rejected_reason=None): + with Measure(self._clock, "_get_event_from_row"): + d = json.loads(js) + internal_metadata = json.loads(internal_metadata) + + if rejected_reason: + rejected_reason = yield self._simple_select_one_onecol( + table="rejections", + keyvalues={"event_id": rejected_reason}, + retcol="reason", + desc="_get_event_from_row_rejected_reason", + ) + + original_ev = FrozenEvent( + d, + internal_metadata_dict=internal_metadata, + rejected_reason=rejected_reason, + ) + + redacted_event = None + if redacted: + redacted_event = prune_event(original_ev) + + redaction_id = yield self._simple_select_one_onecol( + table="redactions", + keyvalues={"redacts": redacted_event.event_id}, + retcol="event_id", + desc="_get_event_from_row_redactions", + ) + + redacted_event.unsigned["redacted_by"] = redaction_id + # Get the redaction event. + + because = yield self.get_event( + redaction_id, + check_redacted=False, + allow_none=True, + ) + + if because: + # It's fine to do add the event directly, since get_pdu_json + # will serialise this field correctly + redacted_event.unsigned["redacted_because"] = because + + cache_entry = _EventCacheEntry( + event=original_ev, + redacted_event=redacted_event, + ) + + self._get_event_cache.prefill((original_ev.event_id,), cache_entry) + + defer.returnValue(cache_entry) diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py index a2ccc66ea7..78b1e30945 100644 --- a/synapse/storage/filtering.py +++ b/synapse/storage/filtering.py @@ -19,6 +19,7 @@ from ._base import SQLBaseStore from synapse.api.errors import SynapseError, Codes from synapse.util.caches.descriptors import cachedInlineCallbacks +from canonicaljson import encode_canonical_json import simplejson as json @@ -46,12 +47,21 @@ class FilteringStore(SQLBaseStore): defer.returnValue(json.loads(str(def_json).decode("utf-8"))) def add_user_filter(self, user_localpart, user_filter): - def_json = json.dumps(user_filter).encode("utf-8") + def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then # INSERT a new one def _do_txn(txn): sql = ( + "SELECT filter_id FROM user_filters " + "WHERE user_id = ? AND filter_json = ?" + ) + txn.execute(sql, (user_localpart, def_json)) + filter_id_response = txn.fetchone() + if filter_id_response is not None: + return filter_id_response[0] + + sql = ( "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?" ) diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py new file mode 100644 index 0000000000..da05ccb027 --- /dev/null +++ b/synapse/storage/group_server.py @@ -0,0 +1,1253 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import SynapseError + +from ._base import SQLBaseStore + +import simplejson as json + + +# The category ID for the "default" category. We don't store as null in the +# database to avoid the fun of null != null +_DEFAULT_CATEGORY_ID = "" +_DEFAULT_ROLE_ID = "" + + +class GroupServerStore(SQLBaseStore): + def set_group_join_policy(self, group_id, join_policy): + """Set the join policy of a group. + + join_policy can be one of: + * "invite" + * "open" + """ + return self._simple_update_one( + table="groups", + keyvalues={ + "group_id": group_id, + }, + updatevalues={ + "join_policy": join_policy, + }, + desc="set_group_join_policy", + ) + + def get_group(self, group_id): + return self._simple_select_one( + table="groups", + keyvalues={ + "group_id": group_id, + }, + retcols=( + "name", "short_description", "long_description", + "avatar_url", "is_public", "join_policy", + ), + allow_none=True, + desc="get_group", + ) + + def get_users_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + return self._simple_select_list( + table="group_users", + keyvalues=keyvalues, + retcols=("user_id", "is_public", "is_admin",), + desc="get_users_in_group", + ) + + def get_invited_users_in_group(self, group_id): + # TODO: Pagination + + return self._simple_select_onecol( + table="group_invites", + keyvalues={ + "group_id": group_id, + }, + retcol="user_id", + desc="get_invited_users_in_group", + ) + + def get_rooms_in_group(self, group_id, include_private=False): + # TODO: Pagination + + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + return self._simple_select_list( + table="group_rooms", + keyvalues=keyvalues, + retcols=("room_id", "is_public",), + desc="get_rooms_in_group", + ) + + def get_rooms_for_summary_by_category(self, group_id, include_private=False): + """Get the rooms and categories that should be included in a summary request + + Returns ([rooms], [categories]) + """ + def _get_rooms_for_summary_txn(txn): + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT room_id, is_public, category_id, room_order + FROM group_summary_rooms + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + rooms = [ + { + "room_id": row[0], + "is_public": row[1], + "category_id": row[2] if row[2] != _DEFAULT_CATEGORY_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT category_id, is_public, profile, cat_order + FROM group_summary_room_categories + INNER JOIN group_room_categories USING (group_id, category_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + categories = { + row[0]: { + "is_public": row[1], + "profile": json.loads(row[2]), + "order": row[3], + } + for row in txn + } + + return rooms, categories + return self.runInteraction( + "get_rooms_for_summary", _get_rooms_for_summary_txn + ) + + def add_room_to_summary(self, group_id, room_id, category_id, order, is_public): + return self.runInteraction( + "add_room_to_summary", self._add_room_to_summary_txn, + group_id, room_id, category_id, order, is_public, + ) + + def _add_room_to_summary_txn(self, txn, group_id, room_id, category_id, order, + is_public): + """Add (or update) room's entry in summary. + + Args: + group_id (str) + room_id (str) + category_id (str): If not None then adds the category to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the room at that position, e.g. + an order of 1 will put the room first. Otherwise, the room gets + added to the end. + """ + room_in_group = self._simple_select_one_onecol_txn( + txn, + table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + retcol="room_id", + allow_none=True, + ) + if not room_in_group: + raise SynapseError(400, "room not in group") + + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + else: + cat_exists = self._simple_select_one_onecol_txn( + txn, + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + raise SynapseError(400, "Category doesn't exist") + + # TODO: Check category is part of summary already + cat_exists = self._simple_select_one_onecol_txn( + txn, + table="group_summary_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcol="group_id", + allow_none=True, + ) + if not cat_exists: + # If not, add it with an order larger than all others + txn.execute(""" + INSERT INTO group_summary_room_categories + (group_id, category_id, cat_order) + SELECT ?, ?, COALESCE(MAX(cat_order), 0) + 1 + FROM group_summary_room_categories + WHERE group_id = ? AND category_id = ? + """, (group_id, category_id, group_id, category_id)) + + existing = self._simple_select_one_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + "category_id": category_id, + }, + retcols=("room_order", "is_public",), + allow_none=True, + ) + + if order is not None: + # Shuffle other room orders that come after the given order + sql = """ + UPDATE group_summary_rooms SET room_order = room_order + 1 + WHERE group_id = ? AND category_id = ? AND room_order >= ? + """ + txn.execute(sql, (group_id, category_id, order,)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(room_order), 0) + 1 FROM group_summary_rooms + WHERE group_id = ? AND category_id = ? + """ + txn.execute(sql, (group_id, category_id,)) + order, = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["room_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self._simple_update_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self._simple_insert_txn( + txn, + table="group_summary_rooms", + values={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + "room_order": order, + "is_public": is_public, + }, + ) + + def remove_room_from_summary(self, group_id, room_id, category_id): + if category_id is None: + category_id = _DEFAULT_CATEGORY_ID + + return self._simple_delete( + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + "room_id": room_id, + }, + desc="remove_room_from_summary", + ) + + @defer.inlineCallbacks + def get_group_categories(self, group_id): + rows = yield self._simple_select_list( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + }, + retcols=("category_id", "is_public", "profile"), + desc="get_group_categories", + ) + + defer.returnValue({ + row["category_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + }) + + @defer.inlineCallbacks + def get_group_category(self, group_id, category_id): + category = yield self._simple_select_one( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + retcols=("is_public", "profile"), + desc="get_group_category", + ) + + category["profile"] = json.loads(category["profile"]) + + defer.returnValue(category) + + def upsert_group_category(self, group_id, category_id, profile, is_public): + """Add/update room category for group + """ + insertion_values = {} + update_values = {"category_id": category_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self._simple_upsert( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_category", + ) + + def remove_group_category(self, group_id, category_id): + return self._simple_delete( + table="group_room_categories", + keyvalues={ + "group_id": group_id, + "category_id": category_id, + }, + desc="remove_group_category", + ) + + @defer.inlineCallbacks + def get_group_roles(self, group_id): + rows = yield self._simple_select_list( + table="group_roles", + keyvalues={ + "group_id": group_id, + }, + retcols=("role_id", "is_public", "profile"), + desc="get_group_roles", + ) + + defer.returnValue({ + row["role_id"]: { + "is_public": row["is_public"], + "profile": json.loads(row["profile"]), + } + for row in rows + }) + + @defer.inlineCallbacks + def get_group_role(self, group_id, role_id): + role = yield self._simple_select_one( + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcols=("is_public", "profile"), + desc="get_group_role", + ) + + role["profile"] = json.loads(role["profile"]) + + defer.returnValue(role) + + def upsert_group_role(self, group_id, role_id, profile, is_public): + """Add/remove user role + """ + insertion_values = {} + update_values = {"role_id": role_id} # This cannot be empty + + if profile is None: + insertion_values["profile"] = "{}" + else: + update_values["profile"] = json.dumps(profile) + + if is_public is None: + insertion_values["is_public"] = True + else: + update_values["is_public"] = is_public + + return self._simple_upsert( + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + values=update_values, + insertion_values=insertion_values, + desc="upsert_group_role", + ) + + def remove_group_role(self, group_id, role_id): + return self._simple_delete( + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + desc="remove_group_role", + ) + + def add_user_to_summary(self, group_id, user_id, role_id, order, is_public): + return self.runInteraction( + "add_user_to_summary", self._add_user_to_summary_txn, + group_id, user_id, role_id, order, is_public, + ) + + def _add_user_to_summary_txn(self, txn, group_id, user_id, role_id, order, + is_public): + """Add (or update) user's entry in summary. + + Args: + group_id (str) + user_id (str) + role_id (str): If not None then adds the role to the end of + the summary if its not already there. [Optional] + order (int): If not None inserts the user at that position, e.g. + an order of 1 will put the user first. Otherwise, the user gets + added to the end. + """ + user_in_group = self._simple_select_one_onecol_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + ) + if not user_in_group: + raise SynapseError(400, "user not in group") + + if role_id is None: + role_id = _DEFAULT_ROLE_ID + else: + role_exists = self._simple_select_one_onecol_txn( + txn, + table="group_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + raise SynapseError(400, "Role doesn't exist") + + # TODO: Check role is part of the summary already + role_exists = self._simple_select_one_onecol_txn( + txn, + table="group_summary_roles", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + }, + retcol="group_id", + allow_none=True, + ) + if not role_exists: + # If not, add it with an order larger than all others + txn.execute(""" + INSERT INTO group_summary_roles + (group_id, role_id, role_order) + SELECT ?, ?, COALESCE(MAX(role_order), 0) + 1 + FROM group_summary_roles + WHERE group_id = ? AND role_id = ? + """, (group_id, role_id, group_id, role_id)) + + existing = self._simple_select_one_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + "role_id": role_id, + }, + retcols=("user_order", "is_public",), + allow_none=True, + ) + + if order is not None: + # Shuffle other users orders that come after the given order + sql = """ + UPDATE group_summary_users SET user_order = user_order + 1 + WHERE group_id = ? AND role_id = ? AND user_order >= ? + """ + txn.execute(sql, (group_id, role_id, order,)) + elif not existing: + sql = """ + SELECT COALESCE(MAX(user_order), 0) + 1 FROM group_summary_users + WHERE group_id = ? AND role_id = ? + """ + txn.execute(sql, (group_id, role_id,)) + order, = txn.fetchone() + + if existing: + to_update = {} + if order is not None: + to_update["user_order"] = order + if is_public is not None: + to_update["is_public"] = is_public + self._simple_update_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + }, + values=to_update, + ) + else: + if is_public is None: + is_public = True + + self._simple_insert_txn( + txn, + table="group_summary_users", + values={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + "user_order": order, + "is_public": is_public, + }, + ) + + def remove_user_from_summary(self, group_id, user_id, role_id): + if role_id is None: + role_id = _DEFAULT_ROLE_ID + + return self._simple_delete( + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "role_id": role_id, + "user_id": user_id, + }, + desc="remove_user_from_summary", + ) + + def get_users_for_summary_by_role(self, group_id, include_private=False): + """Get the users and roles that should be included in a summary request + + Returns ([users], [roles]) + """ + def _get_users_for_summary_txn(txn): + keyvalues = { + "group_id": group_id, + } + if not include_private: + keyvalues["is_public"] = True + + sql = """ + SELECT user_id, is_public, role_id, user_order + FROM group_summary_users + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + users = [ + { + "user_id": row[0], + "is_public": row[1], + "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None, + "order": row[3], + } + for row in txn + ] + + sql = """ + SELECT role_id, is_public, profile, role_order + FROM group_summary_roles + INNER JOIN group_roles USING (group_id, role_id) + WHERE group_id = ? + """ + + if not include_private: + sql += " AND is_public = ?" + txn.execute(sql, (group_id, True)) + else: + txn.execute(sql, (group_id,)) + + roles = { + row[0]: { + "is_public": row[1], + "profile": json.loads(row[2]), + "order": row[3], + } + for row in txn + } + + return users, roles + return self.runInteraction( + "get_users_for_summary_by_role", _get_users_for_summary_txn + ) + + def is_user_in_group(self, user_id, group_id): + return self._simple_select_one_onecol( + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + desc="is_user_in_group", + ).addCallback(lambda r: bool(r)) + + def is_user_admin_in_group(self, group_id, user_id): + return self._simple_select_one_onecol( + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="is_admin", + allow_none=True, + desc="is_user_admin_in_group", + ) + + def add_group_invite(self, group_id, user_id): + """Record that the group server has invited a user + """ + return self._simple_insert( + table="group_invites", + values={ + "group_id": group_id, + "user_id": user_id, + }, + desc="add_group_invite", + ) + + def is_user_invited_to_local_group(self, group_id, user_id): + """Has the group server invited a user? + """ + return self._simple_select_one_onecol( + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + desc="is_user_invited_to_local_group", + allow_none=True, + ) + + def get_users_membership_info_in_group(self, group_id, user_id): + """Get a dict describing the membership of a user in a group. + + Example if joined: + + { + "membership": "join", + "is_public": True, + "is_privileged": False, + } + + Returns an empty dict if the user is not join/invite/etc + """ + def _get_users_membership_in_group_txn(txn): + row = self._simple_select_one_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcols=("is_admin", "is_public"), + allow_none=True, + ) + + if row: + return { + "membership": "join", + "is_public": row["is_public"], + "is_privileged": row["is_admin"], + } + + row = self._simple_select_one_onecol_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcol="user_id", + allow_none=True, + ) + + if row: + return { + "membership": "invite", + } + + return {} + + return self.runInteraction( + "get_users_membership_info_in_group", _get_users_membership_in_group_txn, + ) + + def add_user_to_group(self, group_id, user_id, is_admin=False, is_public=True, + local_attestation=None, remote_attestation=None): + """Add a user to the group server. + + Args: + group_id (str) + user_id (str) + is_admin (bool) + is_public (bool) + local_attestation (dict): The attestation the GS created to give + to the remote server. Optional if the user and group are on the + same server + remote_attestation (dict): The attestation given to GS by remote + server. Optional if the user and group are on the same server + """ + def _add_user_to_group_txn(txn): + self._simple_insert_txn( + txn, + table="group_users", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "is_public": is_public, + }, + ) + + self._simple_delete_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + + if local_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + }, + ) + if remote_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + }, + ) + + return self.runInteraction( + "add_user_to_group", _add_user_to_group_txn + ) + + def remove_user_from_group(self, group_id, user_id): + def _remove_user_from_group_txn(txn): + self._simple_delete_txn( + txn, + table="group_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_invites", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_summary_users", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + return self.runInteraction("remove_user_from_group", _remove_user_from_group_txn) + + def add_room_to_group(self, group_id, room_id, is_public): + return self._simple_insert( + table="group_rooms", + values={ + "group_id": group_id, + "room_id": room_id, + "is_public": is_public, + }, + desc="add_room_to_group", + ) + + def update_room_in_group_visibility(self, group_id, room_id, is_public): + return self._simple_update( + table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + updatevalues={ + "is_public": is_public, + }, + desc="update_room_in_group_visibility", + ) + + def remove_room_from_group(self, group_id, room_id): + def _remove_room_from_group_txn(txn): + self._simple_delete_txn( + txn, + table="group_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + ) + + self._simple_delete_txn( + txn, + table="group_summary_rooms", + keyvalues={ + "group_id": group_id, + "room_id": room_id, + }, + ) + return self.runInteraction( + "remove_room_from_group", _remove_room_from_group_txn, + ) + + def get_publicised_groups_for_user(self, user_id): + """Get all groups a user is publicising + """ + return self._simple_select_onecol( + table="local_group_membership", + keyvalues={ + "user_id": user_id, + "membership": "join", + "is_publicised": True, + }, + retcol="group_id", + desc="get_publicised_groups_for_user", + ) + + def update_group_publicity(self, group_id, user_id, publicise): + """Update whether the user is publicising their membership of the group + """ + return self._simple_update_one( + table="local_group_membership", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + updatevalues={ + "is_publicised": publicise, + }, + desc="update_group_publicity" + ) + + @defer.inlineCallbacks + def register_user_group_membership(self, group_id, user_id, membership, + is_admin=False, content={}, + local_attestation=None, + remote_attestation=None, + is_publicised=False, + ): + """Registers that a local user is a member of a (local or remote) group. + + Args: + group_id (str) + user_id (str) + membership (str) + is_admin (bool) + content (dict): Content of the membership, e.g. includes the inviter + if the user has been invited. + local_attestation (dict): If remote group then store the fact that we + have given out an attestation, else None. + remote_attestation (dict): If remote group then store the remote + attestation from the group, else None. + """ + def _register_user_group_membership_txn(txn, next_id): + # TODO: Upsert? + self._simple_delete_txn( + txn, + table="local_group_membership", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_insert_txn( + txn, + table="local_group_membership", + values={ + "group_id": group_id, + "user_id": user_id, + "is_admin": is_admin, + "membership": membership, + "is_publicised": is_publicised, + "content": json.dumps(content), + }, + ) + + self._simple_insert_txn( + txn, + table="local_group_updates", + values={ + "stream_id": next_id, + "group_id": group_id, + "user_id": user_id, + "type": "membership", + "content": json.dumps({"membership": membership, "content": content}), + } + ) + self._group_updates_stream_cache.entity_has_changed(user_id, next_id) + + # TODO: Insert profile to ensure it comes down stream if its a join. + + if membership == "join": + if local_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_renewals", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": local_attestation["valid_until_ms"], + } + ) + if remote_attestation: + self._simple_insert_txn( + txn, + table="group_attestations_remote", + values={ + "group_id": group_id, + "user_id": user_id, + "valid_until_ms": remote_attestation["valid_until_ms"], + "attestation_json": json.dumps(remote_attestation), + } + ) + else: + self._simple_delete_txn( + txn, + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + self._simple_delete_txn( + txn, + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + ) + + return next_id + + with self._group_updates_id_gen.get_next() as next_id: + res = yield self.runInteraction( + "register_user_group_membership", + _register_user_group_membership_txn, next_id, + ) + defer.returnValue(res) + + @defer.inlineCallbacks + def create_group(self, group_id, user_id, name, avatar_url, short_description, + long_description,): + yield self._simple_insert( + table="groups", + values={ + "group_id": group_id, + "name": name, + "avatar_url": avatar_url, + "short_description": short_description, + "long_description": long_description, + "is_public": True, + }, + desc="create_group", + ) + + @defer.inlineCallbacks + def update_group_profile(self, group_id, profile,): + yield self._simple_update_one( + table="groups", + keyvalues={ + "group_id": group_id, + }, + updatevalues=profile, + desc="update_group_profile", + ) + + def get_attestations_need_renewals(self, valid_until_ms): + """Get all attestations that need to be renewed until givent time + """ + def _get_attestations_need_renewals_txn(txn): + sql = """ + SELECT group_id, user_id FROM group_attestations_renewals + WHERE valid_until_ms <= ? + """ + txn.execute(sql, (valid_until_ms,)) + return self.cursor_to_dict(txn) + return self.runInteraction( + "get_attestations_need_renewals", _get_attestations_need_renewals_txn + ) + + def update_attestation_renewal(self, group_id, user_id, attestation): + """Update an attestation that we have renewed + """ + return self._simple_update_one( + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + updatevalues={ + "valid_until_ms": attestation["valid_until_ms"], + }, + desc="update_attestation_renewal", + ) + + def update_remote_attestion(self, group_id, user_id, attestation): + """Update an attestation that a remote has renewed + """ + return self._simple_update_one( + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + updatevalues={ + "valid_until_ms": attestation["valid_until_ms"], + "attestation_json": json.dumps(attestation) + }, + desc="update_remote_attestion", + ) + + def remove_attestation_renewal(self, group_id, user_id): + """Remove an attestation that we thought we should renew, but actually + shouldn't. Ideally this would never get called as we would never + incorrectly try and do attestations for local users on local groups. + + Args: + group_id (str) + user_id (str) + """ + return self._simple_delete( + table="group_attestations_renewals", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + desc="remove_attestation_renewal", + ) + + @defer.inlineCallbacks + def get_remote_attestation(self, group_id, user_id): + """Get the attestation that proves the remote agrees that the user is + in the group. + """ + row = yield self._simple_select_one( + table="group_attestations_remote", + keyvalues={ + "group_id": group_id, + "user_id": user_id, + }, + retcols=("valid_until_ms", "attestation_json"), + desc="get_remote_attestation", + allow_none=True, + ) + + now = int(self._clock.time_msec()) + if row and now < row["valid_until_ms"]: + defer.returnValue(json.loads(row["attestation_json"])) + + defer.returnValue(None) + + def get_joined_groups(self, user_id): + return self._simple_select_onecol( + table="local_group_membership", + keyvalues={ + "user_id": user_id, + "membership": "join", + }, + retcol="group_id", + desc="get_joined_groups", + ) + + def get_all_groups_for_user(self, user_id, now_token): + def _get_all_groups_for_user_txn(txn): + sql = """ + SELECT group_id, type, membership, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND membership != 'leave' + AND stream_id <= ? + """ + txn.execute(sql, (user_id, now_token,)) + return [ + { + "group_id": row[0], + "type": row[1], + "membership": row[2], + "content": json.loads(row[3]), + } + for row in txn + ] + return self.runInteraction( + "get_all_groups_for_user", _get_all_groups_for_user_txn, + ) + + def get_groups_changes_for_user(self, user_id, from_token, to_token): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_entity_changed( + user_id, from_token, + ) + if not has_changed: + return [] + + def _get_groups_changes_for_user_txn(txn): + sql = """ + SELECT group_id, membership, type, u.content + FROM local_group_updates AS u + INNER JOIN local_group_membership USING (group_id, user_id) + WHERE user_id = ? AND ? < stream_id AND stream_id <= ? + """ + txn.execute(sql, (user_id, from_token, to_token,)) + return [{ + "group_id": group_id, + "membership": membership, + "type": gtype, + "content": json.loads(content_json), + } for group_id, membership, gtype, content_json in txn] + return self.runInteraction( + "get_groups_changes_for_user", _get_groups_changes_for_user_txn, + ) + + def get_all_groups_changes(self, from_token, to_token, limit): + from_token = int(from_token) + has_changed = self._group_updates_stream_cache.has_any_entity_changed( + from_token, + ) + if not has_changed: + return [] + + def _get_all_groups_changes_txn(txn): + sql = """ + SELECT stream_id, group_id, user_id, type, content + FROM local_group_updates + WHERE ? < stream_id AND stream_id <= ? + LIMIT ? + """ + txn.execute(sql, (from_token, to_token, limit,)) + return [( + stream_id, + group_id, + user_id, + gtype, + json.loads(content_json), + ) for stream_id, group_id, user_id, gtype, content_json in txn] + return self.runInteraction( + "get_all_groups_changes", _get_all_groups_changes_txn, + ) + + def get_group_stream_token(self): + return self._group_updates_id_gen.get_current_token() diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 86b37b9ddd..87aeaf71d6 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore): key_ids Args: server_name (str): The name of the server. - key_ids (list of str): List of key_ids to try and look up. + key_ids (iterable[str]): key_ids to try and look up. Returns: - (list of VerifyKey): The verification keys. + Deferred: resolves to dict[str, VerifyKey]: map from + key_id to verification key. """ keys = {} for key_id in key_ids: @@ -112,30 +113,37 @@ class KeyStore(SQLBaseStore): keys[key_id] = key defer.returnValue(keys) - @defer.inlineCallbacks def store_server_verify_key(self, server_name, from_server, time_now_ms, verify_key): """Stores a NACL verification key for the given server. Args: server_name (str): The name of the server. - key_id (str): The version of the key for the server. from_server (str): Where the verification key was looked up - ts_now_ms (int): The time now in milliseconds - verification_key (VerifyKey): The NACL verify key. + time_now_ms (int): The time now in milliseconds + verify_key (nacl.signing.VerifyKey): The NACL verify key. """ - yield self._simple_upsert( - table="server_signature_keys", - keyvalues={ - "server_name": server_name, - "key_id": "%s:%s" % (verify_key.alg, verify_key.version), - }, - values={ - "from_server": from_server, - "ts_added_ms": time_now_ms, - "verify_key": buffer(verify_key.encode()), - }, - desc="store_server_verify_key", - ) + key_id = "%s:%s" % (verify_key.alg, verify_key.version) + + def _txn(txn): + self._simple_upsert_txn( + txn, + table="server_signature_keys", + keyvalues={ + "server_name": server_name, + "key_id": key_id, + }, + values={ + "from_server": from_server, + "ts_added_ms": time_now_ms, + "verify_key": buffer(verify_key.encode()), + }, + ) + txn.call_after( + self._get_server_verify_key.invalidate, + (server_name, key_id) + ) + + return self.runInteraction("store_server_verify_key", _txn) def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 4c0f82353d..e6cdbb0545 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -12,15 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from synapse.storage.background_updates import BackgroundUpdateStore -from ._base import SQLBaseStore - -class MediaRepositoryStore(SQLBaseStore): +class MediaRepositoryStore(BackgroundUpdateStore): """Persistence for attachments and avatars""" - def get_default_thumbnails(self, top_level_type, sub_type): - return [] + def __init__(self, db_conn, hs): + super(MediaRepositoryStore, self).__init__(db_conn, hs) + + self.register_background_index_update( + update_name='local_media_repository_url_idx', + index_name='local_media_repository_url_idx', + table='local_media_repository', + columns=['created_ts'], + where_clause='url_cache IS NOT NULL', + ) def get_local_media(self, media_id): """Get the metadata for a local piece of media @@ -30,13 +37,16 @@ class MediaRepositoryStore(SQLBaseStore): return self._simple_select_one( "local_media_repository", {"media_id": media_id}, - ("media_type", "media_length", "upload_name", "created_ts"), + ( + "media_type", "media_length", "upload_name", "created_ts", + "quarantined_by", "url_cache", + ), allow_none=True, desc="get_local_media", ) def store_local_media(self, media_id, media_type, time_now_ms, upload_name, - media_length, user_id): + media_length, user_id, url_cache=None): return self._simple_insert( "local_media_repository", { @@ -46,6 +56,7 @@ class MediaRepositoryStore(SQLBaseStore): "upload_name": upload_name, "media_length": media_length, "user_id": user_id.to_string(), + "url_cache": url_cache, }, desc="store_local_media", ) @@ -58,7 +69,7 @@ class MediaRepositoryStore(SQLBaseStore): def get_url_cache_txn(txn): # get the most recently cached result (relative to the given ts) sql = ( - "SELECT response_code, etag, expires, og, media_id, download_ts" + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" " FROM local_media_repository_url_cache" " WHERE url = ? AND download_ts <= ?" " ORDER BY download_ts DESC LIMIT 1" @@ -70,7 +81,7 @@ class MediaRepositoryStore(SQLBaseStore): # ...or if we've requested a timestamp older than the oldest # copy in the cache, return the oldest copy (if any) sql = ( - "SELECT response_code, etag, expires, og, media_id, download_ts" + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" " FROM local_media_repository_url_cache" " WHERE url = ? AND download_ts > ?" " ORDER BY download_ts ASC LIMIT 1" @@ -82,14 +93,14 @@ class MediaRepositoryStore(SQLBaseStore): return None return dict(zip(( - 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts' + 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts' ), row)) return self.runInteraction( "get_url_cache", get_url_cache_txn ) - def store_url_cache(self, url, response_code, etag, expires, og, media_id, + def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id, download_ts): return self._simple_insert( "local_media_repository_url_cache", @@ -97,7 +108,7 @@ class MediaRepositoryStore(SQLBaseStore): "url": url, "response_code": response_code, "etag": etag, - "expires": expires, + "expires_ts": expires_ts, "og": og, "media_id": media_id, "download_ts": download_ts, @@ -138,7 +149,7 @@ class MediaRepositoryStore(SQLBaseStore): {"media_origin": origin, "media_id": media_id}, ( "media_type", "media_length", "upload_name", "created_ts", - "filesystem_id", + "filesystem_id", "quarantined_by", ), allow_none=True, desc="get_cached_remote_media", @@ -162,7 +173,14 @@ class MediaRepositoryStore(SQLBaseStore): desc="store_cached_remote_media", ) - def update_cached_last_access_time(self, origin_id_tuples, time_ts): + def update_cached_last_access_time(self, local_media, remote_media, time_ms): + """Updates the last access time of the given media + + Args: + local_media (iterable[str]): Set of media_ids + remote_media (iterable[(str, str)]): Set of (server_name, media_id) + time_ms: Current time in milliseconds + """ def update_cache_txn(txn): sql = ( "UPDATE remote_media_cache SET last_access_ts = ?" @@ -170,8 +188,18 @@ class MediaRepositoryStore(SQLBaseStore): ) txn.executemany(sql, ( - (time_ts, media_origin, media_id) - for media_origin, media_id in origin_id_tuples + (time_ms, media_origin, media_id) + for media_origin, media_id in remote_media + )) + + sql = ( + "UPDATE local_media_repository SET last_access_ts = ?" + " WHERE media_id = ?" + ) + + txn.executemany(sql, ( + (time_ms, media_id) + for media_id in local_media )) return self.runInteraction("update_cached_last_access_time", update_cache_txn) @@ -234,3 +262,70 @@ class MediaRepositoryStore(SQLBaseStore): }, ) return self.runInteraction("delete_remote_media", delete_remote_media_txn) + + def get_expired_url_cache(self, now_ts): + sql = ( + "SELECT media_id FROM local_media_repository_url_cache" + " WHERE expires_ts < ?" + " ORDER BY expires_ts ASC" + " LIMIT 500" + ) + + def _get_expired_url_cache_txn(txn): + txn.execute(sql, (now_ts,)) + return [row[0] for row in txn] + + return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) + + def delete_url_cache(self, media_ids): + if len(media_ids) == 0: + return + + sql = ( + "DELETE FROM local_media_repository_url_cache" + " WHERE media_id = ?" + ) + + def _delete_url_cache_txn(txn): + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction("delete_url_cache", _delete_url_cache_txn) + + def get_url_cache_media_before(self, before_ts): + sql = ( + "SELECT media_id FROM local_media_repository" + " WHERE created_ts < ? AND url_cache IS NOT NULL" + " ORDER BY created_ts ASC" + " LIMIT 500" + ) + + def _get_url_cache_media_before_txn(txn): + txn.execute(sql, (before_ts,)) + return [row[0] for row in txn] + + return self.runInteraction( + "get_url_cache_media_before", _get_url_cache_media_before_txn, + ) + + def delete_url_cache_media(self, media_ids): + if len(media_ids) == 0: + return + + def _delete_url_cache_media_txn(txn): + sql = ( + "DELETE FROM local_media_repository" + " WHERE media_id = ?" + ) + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + sql = ( + "DELETE FROM local_media_repository_thumbnails" + " WHERE media_id = ?" + ) + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction( + "delete_url_cache_media", _delete_url_cache_media_txn, + ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b357f22be7..04411a665f 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,7 +26,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 40 +SCHEMA_VERSION = 48 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -44,6 +45,13 @@ def prepare_database(db_conn, database_engine, config): If `config` is None then prepare_database will assert that no upgrade is necessary, *or* will create a fresh database if the database is empty. + + Args: + db_conn: + database_engine: + config (synapse.config.homeserver.HomeServerConfig|None): + application config, or None if we are connecting to an existing + database which we expect to be configured already """ try: cur = db_conn.cursor() @@ -64,9 +72,13 @@ def prepare_database(db_conn, database_engine, config): else: _setup_new_database(cur, database_engine) + # check if any of our configured dynamic modules want a database + if config is not None: + _apply_module_schemas(cur, database_engine, config) + cur.close() db_conn.commit() - except: + except Exception: db_conn.rollback() raise @@ -283,6 +295,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, ) +def _apply_module_schemas(txn, database_engine, config): + """Apply the module schemas for the dynamic modules, if any + + Args: + cur: database cursor + database_engine: synapse database engine class + config (synapse.config.homeserver.HomeServerConfig): + application config + """ + for (mod, _config) in config.password_providers: + if not hasattr(mod, 'get_db_schema_files'): + continue + modname = ".".join((mod.__module__, mod.__name__)) + _apply_module_schema_files( + txn, database_engine, modname, mod.get_db_schema_files(), + ) + + +def _apply_module_schema_files(cur, database_engine, modname, names_and_streams): + """Apply the module schemas for a single module + + Args: + cur: database cursor + database_engine: synapse database engine class + modname (str): fully qualified name of the module + names_and_streams (Iterable[(str, file)]): the names and streams of + schemas to be applied + """ + cur.execute( + database_engine.convert_param_style( + "SELECT file FROM applied_module_schemas WHERE module_name = ?" + ), + (modname,) + ) + applied_deltas = set(d for d, in cur) + for (name, stream) in names_and_streams: + if name in applied_deltas: + continue + + root_name, ext = os.path.splitext(name) + if ext != '.sql': + raise PrepareDatabaseException( + "only .sql files are currently supported for module schemas", + ) + + logger.info("applying schema %s for %s", name, modname) + for statement in get_statements(stream): + cur.execute(statement) + + # Mark as done. + cur.execute( + database_engine.convert_param_style( + "INSERT INTO applied_module_schemas (module_name, file)" + " VALUES (?,?)", + ), + (modname, name) + ) + + def get_statements(f): statement_buffer = "" in_comment = False # If we're in a /* ... */ style comment @@ -356,7 +427,7 @@ def _get_or_create_schema_state(txn, database_engine): ), (current_version,) ) - applied_deltas = [d for d, in txn.fetchall()] + applied_deltas = [d for d, in txn] return current_version, applied_deltas, upgraded return None diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 4d1590d2b4..9e9d3c2591 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -85,8 +85,8 @@ class PresenceStore(SQLBaseStore): self.presence_stream_cache.entity_has_changed, state.user_id, stream_id, ) - self._invalidate_cache_and_stream( - txn, self._get_presence_for_user, (state.user_id,) + txn.call_after( + self._get_presence_for_user.invalidate, (state.user_id,) ) # Actually insert new rows diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py index 26a40905ae..8612bd5ecc 100644 --- a/synapse/storage/profile.py +++ b/synapse/storage/profile.py @@ -13,15 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.internet import defer + +from synapse.storage.roommember import ProfileInfo +from synapse.api.errors import StoreError + from ._base import SQLBaseStore -class ProfileStore(SQLBaseStore): - def create_profile(self, user_localpart): - return self._simple_insert( - table="profiles", - values={"user_id": user_localpart}, - desc="create_profile", +class ProfileWorkerStore(SQLBaseStore): + @defer.inlineCallbacks + def get_profileinfo(self, user_localpart): + try: + profile = yield self._simple_select_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcols=("displayname", "avatar_url"), + desc="get_profileinfo", + ) + except StoreError as e: + if e.code == 404: + # no match + defer.returnValue(ProfileInfo(None, None)) + return + else: + raise + + defer.returnValue( + ProfileInfo( + avatar_url=profile['avatar_url'], + display_name=profile['displayname'], + ) ) def get_profile_displayname(self, user_localpart): @@ -32,14 +54,6 @@ class ProfileStore(SQLBaseStore): desc="get_profile_displayname", ) - def set_profile_displayname(self, user_localpart, new_displayname): - return self._simple_update_one( - table="profiles", - keyvalues={"user_id": user_localpart}, - updatevalues={"displayname": new_displayname}, - desc="set_profile_displayname", - ) - def get_profile_avatar_url(self, user_localpart): return self._simple_select_one_onecol( table="profiles", @@ -48,6 +62,32 @@ class ProfileStore(SQLBaseStore): desc="get_profile_avatar_url", ) + def get_from_remote_profile_cache(self, user_id): + return self._simple_select_one( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + retcols=("displayname", "avatar_url",), + allow_none=True, + desc="get_from_remote_profile_cache", + ) + + +class ProfileStore(ProfileWorkerStore): + def create_profile(self, user_localpart): + return self._simple_insert( + table="profiles", + values={"user_id": user_localpart}, + desc="create_profile", + ) + + def set_profile_displayname(self, user_localpart, new_displayname): + return self._simple_update_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + updatevalues={"displayname": new_displayname}, + desc="set_profile_displayname", + ) + def set_profile_avatar_url(self, user_localpart, new_avatar_url): return self._simple_update_one( table="profiles", @@ -55,3 +95,90 @@ class ProfileStore(SQLBaseStore): updatevalues={"avatar_url": new_avatar_url}, desc="set_profile_avatar_url", ) + + def add_remote_profile_cache(self, user_id, displayname, avatar_url): + """Ensure we are caching the remote user's profiles. + + This should only be called when `is_subscribed_remote_profile_for_user` + would return true for the user. + """ + return self._simple_upsert( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="add_remote_profile_cache", + ) + + def update_remote_profile_cache(self, user_id, displayname, avatar_url): + return self._simple_update( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + values={ + "displayname": displayname, + "avatar_url": avatar_url, + "last_check": self._clock.time_msec(), + }, + desc="update_remote_profile_cache", + ) + + @defer.inlineCallbacks + def maybe_delete_remote_profile_cache(self, user_id): + """Check if we still care about the remote user's profile, and if we + don't then remove their profile from the cache + """ + subscribed = yield self.is_subscribed_remote_profile_for_user(user_id) + if not subscribed: + yield self._simple_delete( + table="remote_profile_cache", + keyvalues={"user_id": user_id}, + desc="delete_remote_profile_cache", + ) + + def get_remote_profile_cache_entries_that_expire(self, last_checked): + """Get all users who haven't been checked since `last_checked` + """ + def _get_remote_profile_cache_entries_that_expire_txn(txn): + sql = """ + SELECT user_id, displayname, avatar_url + FROM remote_profile_cache + WHERE last_check < ? + """ + + txn.execute(sql, (last_checked,)) + + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_remote_profile_cache_entries_that_expire", + _get_remote_profile_cache_entries_that_expire_txn, + ) + + @defer.inlineCallbacks + def is_subscribed_remote_profile_for_user(self, user_id): + """Check whether we are interested in a remote user's profile. + """ + res = yield self._simple_select_one_onecol( + table="group_users", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + defer.returnValue(True) + + res = yield self._simple_select_one_onecol( + table="group_invites", + keyvalues={"user_id": user_id}, + retcol="user_id", + allow_none=True, + desc="should_update_remote_profile_cache_for_user", + ) + + if res: + defer.returnValue(True) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index cbec255966..04a0b59a39 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,10 +15,17 @@ # limitations under the License. from ._base import SQLBaseStore +from synapse.storage.appservice import ApplicationServiceWorkerStore +from synapse.storage.pusher import PusherWorkerStore +from synapse.storage.receipts import ReceiptsWorkerStore +from synapse.storage.roommember import RoomMemberWorkerStore from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList +from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.push.baserules import list_with_base_rules +from synapse.api.constants import EventTypes from twisted.internet import defer +import abc import logging import simplejson as json @@ -47,8 +55,44 @@ def _load_rules(rawrules, enabled_map): return rules -class PushRuleStore(SQLBaseStore): - @cachedInlineCallbacks() +class PushRulesWorkerStore(ApplicationServiceWorkerStore, + ReceiptsWorkerStore, + PusherWorkerStore, + RoomMemberWorkerStore, + SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_push_rules_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(PushRulesWorkerStore, self).__init__(db_conn, hs) + + push_rules_prefill, push_rules_id = self._get_cache_dict( + db_conn, "push_rules_stream", + entity_column="user_id", + stream_column="stream_id", + max_value=self.get_max_push_rules_stream_id(), + ) + + self.push_rules_stream_cache = StreamChangeCache( + "PushRulesStreamChangeCache", push_rules_id, + prefilled_cache=push_rules_prefill, + ) + + @abc.abstractmethod + def get_max_push_rules_stream_id(self): + """Get the position of the push rules stream. + + Returns: + int + """ + raise NotImplementedError() + + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_for_user(self, user_id): rows = yield self._simple_select_list( table="push_rules", @@ -72,7 +116,7 @@ class PushRuleStore(SQLBaseStore): defer.returnValue(rules) - @cachedInlineCallbacks() + @cachedInlineCallbacks(max_entries=5000) def get_push_rules_enabled_for_user(self, user_id): results = yield self._simple_select_list( table="push_rules_enable", @@ -88,6 +132,22 @@ class PushRuleStore(SQLBaseStore): r['rule_id']: False if r['enabled'] == 0 else True for r in results }) + def have_push_rules_changed_for_user(self, user_id, last_id): + if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): + return defer.succeed(False) + else: + def have_push_rules_changed_txn(txn): + sql = ( + "SELECT COUNT(stream_id) FROM push_rules_stream" + " WHERE user_id = ? AND ? < stream_id" + ) + txn.execute(sql, (user_id, last_id)) + count, = txn.fetchone() + return bool(count) + return self.runInteraction( + "have_push_rules_changed", have_push_rules_changed_txn + ) + @cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1, inlineCallbacks=True) def bulk_get_push_rules(self, user_ids): @@ -184,6 +244,18 @@ class PushRuleStore(SQLBaseStore): if uid in local_users_in_room: user_ids.add(uid) + forgotten = yield self.who_forgot_in_room( + event.room_id, on_invalidate=cache_context.invalidate, + ) + + for row in forgotten: + user_id = row["user_id"] + event_id = row["event_id"] + + mem_id = current_state_ids.get((EventTypes.Member, user_id), None) + if event_id == mem_id: + user_ids.discard(user_id) + rules_by_user = yield self.bulk_get_push_rules( user_ids, on_invalidate=cache_context.invalidate, ) @@ -215,6 +287,8 @@ class PushRuleStore(SQLBaseStore): results.setdefault(row['user_name'], {})[row['rule_id']] = enabled defer.returnValue(results) + +class PushRuleStore(PushRulesWorkerStore): @defer.inlineCallbacks def add_push_rule( self, user_id, rule_id, priority_class, conditions, actions, @@ -513,21 +587,8 @@ class PushRuleStore(SQLBaseStore): room stream ordering it corresponds to.""" return self._push_rules_stream_id_gen.get_current_token() - def have_push_rules_changed_for_user(self, user_id, last_id): - if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) - else: - def have_push_rules_changed_txn(txn): - sql = ( - "SELECT COUNT(stream_id) FROM push_rules_stream" - " WHERE user_id = ? AND ? < stream_id" - ) - txn.execute(sql, (user_id, last_id)) - count, = txn.fetchone() - return bool(count) - return self.runInteraction( - "have_push_rules_changed", have_push_rules_changed_txn - ) + def get_max_push_rules_stream_id(self): + return self.get_push_rules_stream_token()[0] class RuleNotFoundException(Exception): diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py index 8cc9f0353b..307660b99a 100644 --- a/synapse/storage/pusher.py +++ b/synapse/storage/pusher.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -27,7 +28,7 @@ import types logger = logging.getLogger(__name__) -class PusherStore(SQLBaseStore): +class PusherWorkerStore(SQLBaseStore): def _decode_pushers_rows(self, rows): for r in rows: dataJson = r['data'] @@ -102,9 +103,6 @@ class PusherStore(SQLBaseStore): rows = yield self.runInteraction("get_all_pushers", get_pushers) defer.returnValue(rows) - def get_pushers_stream_token(self): - return self._pushers_id_gen.get_current_token() - def get_all_updated_pushers(self, last_id, current_id, limit): if last_id == current_id: return defer.succeed(([], [])) @@ -135,6 +133,48 @@ class PusherStore(SQLBaseStore): "get_all_updated_pushers", get_all_updated_pushers_txn ) + def get_all_updated_pushers_rows(self, last_id, current_id, limit): + """Get all the pushers that have changed between the given tokens. + + Returns: + Deferred(list(tuple)): each tuple consists of: + stream_id (str) + user_id (str) + app_id (str) + pushkey (str) + was_deleted (bool): whether the pusher was added/updated (False) + or deleted (True) + """ + + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_pushers_rows_txn(txn): + sql = ( + "SELECT id, user_name, app_id, pushkey" + " FROM pushers" + " WHERE ? < id AND id <= ?" + " ORDER BY id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + results = [list(row) + [False] for row in txn] + + sql = ( + "SELECT stream_id, user_id, app_id, pushkey" + " FROM deleted_pushers" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC LIMIT ?" + ) + txn.execute(sql, (last_id, current_id, limit)) + + results.extend(list(row) + [True] for row in txn) + results.sort() # Sort so that they're ordered by stream id + + return results + return self.runInteraction( + "get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn + ) + @cachedInlineCallbacks(num_args=1, max_entries=15000) def get_if_user_has_pusher(self, user_id): # This only exists for the cachedList decorator @@ -156,56 +196,74 @@ class PusherStore(SQLBaseStore): defer.returnValue(result) + +class PusherStore(PusherWorkerStore): + def get_pushers_stream_token(self): + return self._pushers_id_gen.get_current_token() + @defer.inlineCallbacks def add_pusher(self, user_id, access_token, kind, app_id, app_display_name, device_display_name, pushkey, pushkey_ts, lang, data, last_stream_ordering, profile_tag=""): with self._pushers_id_gen.get_next() as stream_id: - def f(txn): - newly_inserted = self._simple_upsert_txn( - txn, - "pushers", - { - "app_id": app_id, - "pushkey": pushkey, - "user_name": user_id, - }, - { - "access_token": access_token, - "kind": kind, - "app_display_name": app_display_name, - "device_display_name": device_display_name, - "ts": pushkey_ts, - "lang": lang, - "data": encode_canonical_json(data), - "last_stream_ordering": last_stream_ordering, - "profile_tag": profile_tag, - "id": stream_id, - }, - ) - if newly_inserted: - # get_if_user_has_pusher only cares if the user has - # at least *one* pusher. - txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) + # no need to lock because `pushers` has a unique key on + # (app_id, pushkey, user_name) so _simple_upsert will retry + newly_inserted = yield self._simple_upsert( + table="pushers", + keyvalues={ + "app_id": app_id, + "pushkey": pushkey, + "user_name": user_id, + }, + values={ + "access_token": access_token, + "kind": kind, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "ts": pushkey_ts, + "lang": lang, + "data": encode_canonical_json(data), + "last_stream_ordering": last_stream_ordering, + "profile_tag": profile_tag, + "id": stream_id, + }, + desc="add_pusher", + lock=False, + ) - yield self.runInteraction("add_pusher", f) + if newly_inserted: + self.runInteraction( + "add_pusher", + self._invalidate_cache_and_stream, + self.get_if_user_has_pusher, (user_id,) + ) @defer.inlineCallbacks def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id): def delete_pusher_txn(txn, stream_id): - txn.call_after(self.get_if_user_has_pusher.invalidate, (user_id,)) + self._invalidate_cache_and_stream( + txn, self.get_if_user_has_pusher, (user_id,) + ) self._simple_delete_one_txn( txn, "pushers", {"app_id": app_id, "pushkey": pushkey, "user_name": user_id} ) - self._simple_upsert_txn( + + # it's possible for us to end up with duplicate rows for + # (app_id, pushkey, user_id) at different stream_ids, but that + # doesn't really matter. + self._simple_insert_txn( txn, - "deleted_pushers", - {"app_id": app_id, "pushkey": pushkey, "user_id": user_id}, - {"stream_id": stream_id}, + table="deleted_pushers", + values={ + "stream_id": stream_id, + "app_id": app_id, + "pushkey": pushkey, + "user_id": user_id, + }, ) with self._pushers_id_gen.get_next() as stream_id: @@ -268,9 +326,12 @@ class PusherStore(SQLBaseStore): @defer.inlineCallbacks def set_throttle_params(self, pusher_id, room_id, params): + # no need to lock because `pusher_throttle` has a primary key on + # (pusher, room_id) so _simple_upsert will retry yield self._simple_upsert( "pusher_throttle", {"pusher": pusher_id, "room_id": room_id}, params, - desc="set_throttle_params" + desc="set_throttle_params", + lock=False, ) diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index f72d15f5ed..63997ed449 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,46 +15,50 @@ # limitations under the License. from ._base import SQLBaseStore +from .util.id_generators import StreamIdGenerator from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached from synapse.util.caches.stream_change_cache import StreamChangeCache from twisted.internet import defer +import abc import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) -class ReceiptsStore(SQLBaseStore): - def __init__(self, hs): - super(ReceiptsStore, self).__init__(hs) +class ReceiptsWorkerStore(SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_max_receipt_stream_id` which can be called in the initializer. + """ + + # This ABCMeta metaclass ensures that we cannot be instantiated without + # the abstract methods being implemented. + __metaclass__ = abc.ABCMeta + + def __init__(self, db_conn, hs): + super(ReceiptsWorkerStore, self).__init__(db_conn, hs) self._receipts_stream_cache = StreamChangeCache( - "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token() + "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) + @abc.abstractmethod + def get_max_receipt_stream_id(self): + """Get the current max stream ID for receipts stream + + Returns: + int + """ + raise NotImplementedError() + @cachedInlineCallbacks() def get_users_with_read_receipts_in_room(self, room_id): receipts = yield self.get_receipts_for_room(room_id, "m.read") defer.returnValue(set(r['user_id'] for r in receipts)) - def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, - user_id): - if receipt_type != "m.read": - return - - # Returns an ObservableDeferred - res = self.get_users_with_read_receipts_in_room.cache.get((room_id,), None) - - if res and res.called and user_id in res.result: - # We'd only be adding to the set, so no point invalidating if the - # user is already there - return - - self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - @cached(num_args=2) def get_receipts_for_room(self, room_id, receipt_type): return self._simple_select_list( @@ -265,6 +270,59 @@ class ReceiptsStore(SQLBaseStore): } defer.returnValue(results) + def get_all_updated_receipts(self, last_id, current_id, limit=None): + if last_id == current_id: + return defer.succeed([]) + + def get_all_updated_receipts_txn(txn): + sql = ( + "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" + " FROM receipts_linearized" + " WHERE ? < stream_id AND stream_id <= ?" + " ORDER BY stream_id ASC" + ) + args = [last_id, current_id] + if limit is not None: + sql += " LIMIT ?" + args.append(limit) + txn.execute(sql, args) + + return txn.fetchall() + return self.runInteraction( + "get_all_updated_receipts", get_all_updated_receipts_txn + ) + + def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type, + user_id): + if receipt_type != "m.read": + return + + # Returns an ObservableDeferred + res = self.get_users_with_read_receipts_in_room.cache.get( + room_id, None, update_metrics=False, + ) + + if res: + if isinstance(res, defer.Deferred) and res.called: + res = res.result + if user_id in res: + # We'd only be adding to the set, so no point invalidating if the + # user is already there + return + + self.get_users_with_read_receipts_in_room.invalidate((room_id,)) + + +class ReceiptsStore(ReceiptsWorkerStore): + def __init__(self, db_conn, hs): + # We instantiate this first as the ReceiptsWorkerStore constructor + # needs to be able to call get_max_receipt_stream_id + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + + super(ReceiptsStore, self).__init__(db_conn, hs) + def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() @@ -313,10 +371,9 @@ class ReceiptsStore(SQLBaseStore): ) txn.execute(sql, (room_id, receipt_type, user_id)) - results = txn.fetchall() - if results and topological_ordering: - for to, so, _ in results: + if topological_ordering: + for to, so, _ in txn: if int(to) > topological_ordering: return False elif int(to) == topological_ordering and int(so) >= stream_ordering: @@ -351,6 +408,7 @@ class ReceiptsStore(SQLBaseStore): room_id=room_id, user_id=user_id, topological_ordering=topological_ordering, + stream_ordering=stream_ordering, ) return True @@ -452,25 +510,3 @@ class ReceiptsStore(SQLBaseStore): "data": json.dumps(data), } ) - - def get_all_updated_receipts(self, last_id, current_id, limit=None): - if last_id == current_id: - return defer.succeed([]) - - def get_all_updated_receipts_txn(txn): - sql = ( - "SELECT stream_id, room_id, receipt_type, user_id, event_id, data" - " FROM receipts_linearized" - " WHERE ? < stream_id AND stream_id <= ?" - " ORDER BY stream_id ASC" - ) - args = [last_id, current_id] - if limit is not None: - sql += " LIMIT ?" - args.append(limit) - txn.execute(sql, args) - - return txn.fetchall() - return self.runInteraction( - "get_all_updated_receipts", get_all_updated_receipts_txn - ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 26be6060c3..a50717db2d 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -19,13 +19,75 @@ from twisted.internet import defer from synapse.api.errors import StoreError, Codes from synapse.storage import background_updates +from synapse.storage._base import SQLBaseStore from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from six.moves import range -class RegistrationStore(background_updates.BackgroundUpdateStore): - def __init__(self, hs): - super(RegistrationStore, self).__init__(hs) +class RegistrationWorkerStore(SQLBaseStore): + @cached() + def get_user_by_id(self, user_id): + return self._simple_select_one( + table="users", + keyvalues={ + "name": user_id, + }, + retcols=["name", "password_hash", "is_guest"], + allow_none=True, + desc="get_user_by_id", + ) + + @cached() + def get_user_by_access_token(self, token): + """Get a user from the given access token. + + Args: + token (str): The access token of a user. + Returns: + defer.Deferred: None, if the token did not match, otherwise dict + including the keys `name`, `is_guest`, `device_id`, `token_id`. + """ + return self.runInteraction( + "get_user_by_access_token", + self._query_for_auth, + token + ) + + @defer.inlineCallbacks + def is_server_admin(self, user): + res = yield self._simple_select_one_onecol( + table="users", + keyvalues={"name": user.to_string()}, + retcol="admin", + allow_none=True, + desc="is_server_admin", + ) + + defer.returnValue(res if res else False) + + def _query_for_auth(self, txn, token): + sql = ( + "SELECT users.name, users.is_guest, access_tokens.id as token_id," + " access_tokens.device_id" + " FROM users" + " INNER JOIN access_tokens on users.name = access_tokens.user_id" + " WHERE token = ?" + ) + + txn.execute(sql, (token,)) + rows = self.cursor_to_dict(txn) + if rows: + return rows[0] + + return None + + +class RegistrationStore(RegistrationWorkerStore, + background_updates.BackgroundUpdateStore): + + def __init__(self, db_conn, hs): + super(RegistrationStore, self).__init__(db_conn, hs) self.clock = hs.get_clock() @@ -36,12 +98,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): columns=["user_id", "device_id"], ) - self.register_background_index_update( - "refresh_tokens_device_index", - index_name="refresh_tokens_device_id", - table="refresh_tokens", - columns=["user_id", "device_id"], - ) + # we no longer use refresh tokens, but it's possible that some people + # might have a background update queued to build this index. Just + # clear the background update. + self.register_noop_background_update("refresh_tokens_device_index") @defer.inlineCallbacks def add_access_token_to_user(self, user_id, token, device_id=None): @@ -177,9 +237,11 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ) if create_profile_with_localpart: + # set a default displayname serverside to avoid ugly race + # between auto-joins and clients trying to set displaynames txn.execute( - "INSERT INTO profiles(user_id) VALUES (?)", - (create_profile_with_localpart,) + "INSERT INTO profiles(user_id, displayname) VALUES (?,?)", + (create_profile_with_localpart, create_profile_with_localpart) ) self._invalidate_cache_and_stream( @@ -187,18 +249,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): ) txn.call_after(self.is_guest.invalidate, (user_id,)) - @cached() - def get_user_by_id(self, user_id): - return self._simple_select_one( - table="users", - keyvalues={ - "name": user_id, - }, - retcols=["name", "password_hash", "is_guest"], - allow_none=True, - desc="get_user_by_id", - ) - def get_users_by_id_case_insensitive(self, user_id): """Gets users that match user_id case insensitively. Returns a mapping of user_id -> password_hash. @@ -209,7 +259,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): " WHERE lower(name) = lower(?)" ) txn.execute(sql, (user_id,)) - return dict(txn.fetchall()) + return dict(txn) return self.runInteraction("get_users_by_id_case_insensitive", f) @@ -236,12 +286,10 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): "user_set_password_hash", user_set_password_hash_txn ) - @defer.inlineCallbacks def user_delete_access_tokens(self, user_id, except_token_id=None, - device_id=None, - delete_refresh_tokens=False): + device_id=None): """ - Invalidate access/refresh tokens belonging to a user + Invalidate access tokens belonging to a user Args: user_id (str): ID of user the tokens belong to @@ -250,10 +298,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): device_id (str|None): ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted - delete_refresh_tokens (bool): True to delete refresh tokens as - well as access tokens. Returns: - defer.Deferred: + defer.Deferred[list[str, int, str|None, int]]: a list of + (token, token id, device id) for each of the deleted tokens """ def f(txn): keyvalues = { @@ -262,13 +309,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): if device_id is not None: keyvalues["device_id"] = device_id - if delete_refresh_tokens: - self._simple_delete_txn( - txn, - table="refresh_tokens", - keyvalues=keyvalues, - ) - items = keyvalues.items() where_clause = " AND ".join(k + " = ?" for k, _ in items) values = [v for _, v in items] @@ -277,14 +317,14 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values.append(except_token_id) txn.execute( - "SELECT token FROM access_tokens WHERE %s" % where_clause, + "SELECT token, id, device_id FROM access_tokens WHERE %s" % where_clause, values ) - rows = self.cursor_to_dict(txn) + tokens_and_devices = [(r[0], r[1], r[2]) for r in txn] - for row in rows: + for token, _, _ in tokens_and_devices: self._invalidate_cache_and_stream( - txn, self.get_user_by_access_token, (row["token"],) + txn, self.get_user_by_access_token, (token,) ) txn.execute( @@ -292,7 +332,9 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): values ) - yield self.runInteraction( + return tokens_and_devices + + return self.runInteraction( "user_delete_access_tokens", f, ) @@ -312,34 +354,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): return self.runInteraction("delete_access_token", f) - @cached() - def get_user_by_access_token(self, token): - """Get a user from the given access token. - - Args: - token (str): The access token of a user. - Returns: - defer.Deferred: None, if the token did not match, otherwise dict - including the keys `name`, `is_guest`, `device_id`, `token_id`. - """ - return self.runInteraction( - "get_user_by_access_token", - self._query_for_auth, - token - ) - - @defer.inlineCallbacks - def is_server_admin(self, user): - res = yield self._simple_select_one_onecol( - table="users", - keyvalues={"name": user.to_string()}, - retcol="admin", - allow_none=True, - desc="is_server_admin", - ) - - defer.returnValue(res if res else False) - @cachedInlineCallbacks() def is_guest(self, user_id): res = yield self._simple_select_one_onecol( @@ -352,22 +366,6 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): defer.returnValue(res if res else False) - def _query_for_auth(self, txn, token): - sql = ( - "SELECT users.name, users.is_guest, access_tokens.id as token_id," - " access_tokens.device_id" - " FROM users" - " INNER JOIN access_tokens on users.name = access_tokens.user_id" - " WHERE token = ?" - ) - - txn.execute(sql, (token,)) - rows = self.cursor_to_dict(txn) - if rows: - return rows[0] - - return None - @defer.inlineCallbacks def user_add_threepid(self, user_id, medium, address, validated_at, added_at): yield self._simple_upsert("user_threepids", { @@ -438,6 +436,19 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): defer.returnValue(ret) @defer.inlineCallbacks + def count_nonbridged_users(self): + def _count_users(txn): + txn.execute(""" + SELECT COALESCE(COUNT(*), 0) FROM users + WHERE appservice_id IS NULL + """) + count, = txn.fetchone() + return count + + ret = yield self.runInteraction("count_users", _count_users) + defer.returnValue(ret) + + @defer.inlineCallbacks def find_next_generated_user_id_localpart(self): """ Gets the localpart of the next generated user ID. @@ -451,18 +462,16 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): """ def _find_next_generated_user_id(txn): txn.execute("SELECT name FROM users") - rows = self.cursor_to_dict(txn) regex = re.compile("^@(\d+):") found = set() - for r in rows: - user_id = r["name"] + for user_id, in txn: match = regex.search(user_id) if match: found.add(int(match.group(1))) - for i in xrange(len(found) + 1): + for i in range(len(found) + 1): if i not in found: return i diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 8a2fe2fdf5..ea6a189185 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -16,14 +16,14 @@ from twisted.internet import defer from synapse.api.errors import StoreError -from synapse.util.caches.descriptors import cached - -from ._base import SQLBaseStore -from .engines import PostgresEngine, Sqlite3Engine +from synapse.storage._base import SQLBaseStore +from synapse.storage.search import SearchStore +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks import collections import logging -import ujson as json +import simplejson as json +import re logger = logging.getLogger(__name__) @@ -33,8 +33,144 @@ OpsLevel = collections.namedtuple( ("ban_level", "kick_level", "redact_level",) ) +RatelimitOverride = collections.namedtuple( + "RatelimitOverride", + ("messages_per_second", "burst_count",) +) + -class RoomStore(SQLBaseStore): +class RoomWorkerStore(SQLBaseStore): + def get_public_room_ids(self): + return self._simple_select_onecol( + table="rooms", + keyvalues={ + "is_public": True, + }, + retcol="room_id", + desc="get_public_room_ids", + ) + + @cached(num_args=2, max_entries=100) + def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): + """Get pulbic rooms for a particular list, or across all lists. + + Args: + stream_id (int) + network_tuple (ThirdPartyInstanceID): The list to use (None, None) + means the main list, None means all lsits. + """ + return self.runInteraction( + "get_public_room_ids_at_stream_id", + self.get_public_room_ids_at_stream_id_txn, + stream_id, network_tuple=network_tuple + ) + + def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, + network_tuple): + return { + rm + for rm, vis in self.get_published_at_stream_id_txn( + txn, stream_id, network_tuple=network_tuple + ).items() + if vis + } + + def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): + if network_tuple: + # We want to get from a particular list. No aggregation required. + + sql = (""" + SELECT room_id, visibility FROM public_room_list_stream + INNER JOIN ( + SELECT room_id, max(stream_id) AS stream_id + FROM public_room_list_stream + WHERE stream_id <= ? %s + GROUP BY room_id + ) grouped USING (room_id, stream_id) + """) + + if network_tuple.appservice_id is not None: + txn.execute( + sql % ("AND appservice_id = ? AND network_id = ?",), + (stream_id, network_tuple.appservice_id, network_tuple.network_id,) + ) + else: + txn.execute( + sql % ("AND appservice_id IS NULL",), + (stream_id,) + ) + return dict(txn) + else: + # We want to get from all lists, so we need to aggregate the results + + logger.info("Executing full list") + + sql = (""" + SELECT room_id, visibility + FROM public_room_list_stream + INNER JOIN ( + SELECT + room_id, max(stream_id) AS stream_id, appservice_id, + network_id + FROM public_room_list_stream + WHERE stream_id <= ? + GROUP BY room_id, appservice_id, network_id + ) grouped USING (room_id, stream_id) + """) + + txn.execute( + sql, + (stream_id,) + ) + + results = {} + # A room is visible if its visible on any list. + for room_id, visibility in txn: + results[room_id] = bool(visibility) or results.get(room_id, False) + + return results + + def get_public_room_changes(self, prev_stream_id, new_stream_id, + network_tuple): + def get_public_room_changes_txn(txn): + then_rooms = self.get_public_room_ids_at_stream_id_txn( + txn, prev_stream_id, network_tuple + ) + + now_rooms_dict = self.get_published_at_stream_id_txn( + txn, new_stream_id, network_tuple + ) + + now_rooms_visible = set( + rm for rm, vis in now_rooms_dict.items() if vis + ) + now_rooms_not_visible = set( + rm for rm, vis in now_rooms_dict.items() if not vis + ) + + newly_visible = now_rooms_visible - then_rooms + newly_unpublished = now_rooms_not_visible & then_rooms + + return newly_visible, newly_unpublished + + return self.runInteraction( + "get_public_room_changes", get_public_room_changes_txn + ) + + @cached(max_entries=10000) + def is_room_blocked(self, room_id): + return self._simple_select_one_onecol( + table="blocked_rooms", + keyvalues={ + "room_id": room_id, + }, + retcol="1", + allow_none=True, + desc="is_room_blocked", + ) + + +class RoomStore(RoomWorkerStore, SearchStore): @defer.inlineCallbacks def store_room(self, room_id, room_creator_user_id, is_public): @@ -221,16 +357,6 @@ class RoomStore(SQLBaseStore): ) self.hs.get_notifier().on_new_replication_data() - def get_public_room_ids(self): - return self._simple_select_onecol( - table="rooms", - keyvalues={ - "is_public": True, - }, - retcol="room_id", - desc="get_public_room_ids", - ) - def get_room_count(self): """Retrieve a list of all rooms """ @@ -257,8 +383,8 @@ class RoomStore(SQLBaseStore): }, ) - self._store_event_search_txn( - txn, event, "content.topic", event.content["topic"] + self.store_event_search_txn( + txn, event, "content.topic", event.content["topic"], ) def _store_room_name_txn(self, txn, event): @@ -273,14 +399,14 @@ class RoomStore(SQLBaseStore): } ) - self._store_event_search_txn( - txn, event, "content.name", event.content["name"] + self.store_event_search_txn( + txn, event, "content.name", event.content["name"], ) def _store_room_message_txn(self, txn, event): if hasattr(event, "content") and "body" in event.content: - self._store_event_search_txn( - txn, event, "content.body", event.content["body"] + self.store_event_search_txn( + txn, event, "content.body", event.content["body"], ) def _store_history_visibility_txn(self, txn, event): @@ -302,31 +428,6 @@ class RoomStore(SQLBaseStore): event.content[key] )) - def _store_event_search_txn(self, txn, event, key, value): - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search" - " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" - " VALUES (?,?,?,to_tsvector('english', ?),?,?)" - ) - txn.execute( - sql, - ( - event.event_id, event.room_id, key, value, - event.internal_metadata.stream_ordering, - event.origin_server_ts, - ) - ) - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - txn.execute(sql, (event.event_id, event.room_id, key, value,)) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - def add_event_report(self, room_id, event_id, user_id, reason, content, received_ts): next_id = self._event_reports_id_gen.get_next() @@ -347,129 +448,180 @@ class RoomStore(SQLBaseStore): def get_current_public_room_stream_id(self): return self._public_room_id_gen.get_current_token() - @cached(num_args=2, max_entries=100) - def get_public_room_ids_at_stream_id(self, stream_id, network_tuple): - """Get pulbic rooms for a particular list, or across all lists. + def get_all_new_public_rooms(self, prev_id, current_id, limit): + def get_all_new_public_rooms(txn): + sql = (""" + SELECT stream_id, room_id, visibility, appservice_id, network_id + FROM public_room_list_stream + WHERE stream_id > ? AND stream_id <= ? + ORDER BY stream_id ASC + LIMIT ? + """) + + txn.execute(sql, (prev_id, current_id, limit,)) + return txn.fetchall() + + if prev_id == current_id: + return defer.succeed([]) - Args: - stream_id (int) - network_tuple (ThirdPartyInstanceID): The list to use (None, None) - means the main list, None means all lsits. - """ return self.runInteraction( - "get_public_room_ids_at_stream_id", - self.get_public_room_ids_at_stream_id_txn, - stream_id, network_tuple=network_tuple + "get_all_new_public_rooms", get_all_new_public_rooms ) - def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, - network_tuple): - return { - rm - for rm, vis in self.get_published_at_stream_id_txn( - txn, stream_id, network_tuple=network_tuple - ).items() - if vis - } + @cachedInlineCallbacks(max_entries=10000) + def get_ratelimit_for_user(self, user_id): + """Check if there are any overrides for ratelimiting for the given + user - def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple): - if network_tuple: - # We want to get from a particular list. No aggregation required. + Args: + user_id (str) - sql = (""" - SELECT room_id, visibility FROM public_room_list_stream - INNER JOIN ( - SELECT room_id, max(stream_id) AS stream_id - FROM public_room_list_stream - WHERE stream_id <= ? %s - GROUP BY room_id - ) grouped USING (room_id, stream_id) - """) + Returns: + RatelimitOverride if there is an override, else None. If the contents + of RatelimitOverride are None or 0 then ratelimitng has been + disabled for that user entirely. + """ + row = yield self._simple_select_one( + table="ratelimit_override", + keyvalues={"user_id": user_id}, + retcols=("messages_per_second", "burst_count"), + allow_none=True, + desc="get_ratelimit_for_user", + ) - if network_tuple.appservice_id is not None: - txn.execute( - sql % ("AND appservice_id = ? AND network_id = ?",), - (stream_id, network_tuple.appservice_id, network_tuple.network_id,) - ) - else: - txn.execute( - sql % ("AND appservice_id IS NULL",), - (stream_id,) - ) - return dict(txn.fetchall()) + if row: + defer.returnValue(RatelimitOverride( + messages_per_second=row["messages_per_second"], + burst_count=row["burst_count"], + )) else: - # We want to get from all lists, so we need to aggregate the results + defer.returnValue(None) - logger.info("Executing full list") - - sql = (""" - SELECT room_id, visibility - FROM public_room_list_stream - INNER JOIN ( - SELECT - room_id, max(stream_id) AS stream_id, appservice_id, - network_id - FROM public_room_list_stream - WHERE stream_id <= ? - GROUP BY room_id, appservice_id, network_id - ) grouped USING (room_id, stream_id) - """) - - txn.execute( - sql, - (stream_id,) - ) - - results = {} - # A room is visible if its visible on any list. - for room_id, visibility in txn.fetchall(): - results[room_id] = bool(visibility) or results.get(room_id, False) - - return results + @defer.inlineCallbacks + def block_room(self, room_id, user_id): + yield self._simple_insert( + table="blocked_rooms", + values={ + "room_id": room_id, + "user_id": user_id, + }, + desc="block_room", + ) + yield self.runInteraction( + "block_room_invalidation", + self._invalidate_cache_and_stream, + self.is_room_blocked, (room_id,), + ) - def get_public_room_changes(self, prev_stream_id, new_stream_id, - network_tuple): - def get_public_room_changes_txn(txn): - then_rooms = self.get_public_room_ids_at_stream_id_txn( - txn, prev_stream_id, network_tuple - ) + def get_media_mxcs_in_room(self, room_id): + """Retrieves all the local and remote media MXC URIs in a given room - now_rooms_dict = self.get_published_at_stream_id_txn( - txn, new_stream_id, network_tuple - ) + Args: + room_id (str) - now_rooms_visible = set( - rm for rm, vis in now_rooms_dict.items() if vis - ) - now_rooms_not_visible = set( - rm for rm, vis in now_rooms_dict.items() if not vis + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + def _get_media_mxcs_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + local_media_mxcs = [] + remote_media_mxcs = [] + + # Convert the IDs to MXC URIs + for media_id in local_mxcs: + local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id)) + for hostname, media_id in remote_mxcs: + remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs + return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn) + + def quarantine_media_ids_in_room(self, room_id, quarantined_by): + """For a room loops through all events with media and quarantines + the associated media + """ + def _quarantine_media_in_room_txn(txn): + local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id) + total_media_quarantined = 0 + + # Now update all the tables to set the quarantined_by flag + + txn.executemany(""" + UPDATE local_media_repository + SET quarantined_by = ? + WHERE media_id = ? + """, ((quarantined_by, media_id) for media_id in local_mxcs)) + + txn.executemany( + """ + UPDATE remote_media_cache + SET quarantined_by = ? + WHERE media_origin = ? AND media_id = ? + """, + ( + (quarantined_by, origin, media_id) + for origin, media_id in remote_mxcs + ) ) - newly_visible = now_rooms_visible - then_rooms - newly_unpublished = now_rooms_not_visible & then_rooms + total_media_quarantined += len(local_mxcs) + total_media_quarantined += len(remote_mxcs) - return newly_visible, newly_unpublished + return total_media_quarantined return self.runInteraction( - "get_public_room_changes", get_public_room_changes_txn + "quarantine_media_in_room", + _quarantine_media_in_room_txn, ) - def get_all_new_public_rooms(self, prev_id, current_id, limit): - def get_all_new_public_rooms(txn): - sql = (""" - SELECT stream_id, room_id, visibility, appservice_id, network_id - FROM public_room_list_stream - WHERE stream_id > ? AND stream_id <= ? - ORDER BY stream_id ASC - LIMIT ? - """) - - txn.execute(sql, (prev_id, current_id, limit,)) - return txn.fetchall() + def _get_media_mxcs_in_room_txn(self, txn, room_id): + """Retrieves all the local and remote media MXC URIs in a given room - if prev_id == current_id: - return defer.succeed([]) + Args: + txn (cursor) + room_id (str) - return self.runInteraction( - "get_all_new_public_rooms", get_all_new_public_rooms - ) + Returns: + The local and remote media as a lists of tuples where the key is + the hostname and the value is the media ID. + """ + mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") + + next_token = self.get_current_events_token() + 1 + local_media_mxcs = [] + remote_media_mxcs = [] + + while next_token: + sql = """ + SELECT stream_ordering, json FROM events + JOIN event_json USING (room_id, event_id) + WHERE room_id = ? + AND stream_ordering < ? + AND contains_url = ? AND outlier = ? + ORDER BY stream_ordering DESC + LIMIT ? + """ + txn.execute(sql, (room_id, next_token, True, False, 100)) + + next_token = None + for stream_ordering, content_json in txn: + next_token = stream_ordering + event_json = json.loads(content_json) + content = event_json["content"] + content_url = content.get("url") + thumbnail_url = content.get("info", {}).get("thumbnail_url") + + for url in (content_url, thumbnail_url): + if not url: + continue + matches = mxc_re.match(url) + if matches: + hostname = matches.group(1) + media_id = matches.group(2) + if hostname == self.hs.hostname: + local_media_mxcs.append(media_id) + else: + remote_media_mxcs.append((hostname, media_id)) + + return local_media_mxcs, remote_media_mxcs diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 545d3d3a99..6a861943a2 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,14 +18,17 @@ from twisted.internet import defer from collections import namedtuple -from ._base import SQLBaseStore +from synapse.storage.events import EventsWorkerStore +from synapse.util.async import Linearizer +from synapse.util.caches import intern_string from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.stringutils import to_ascii from synapse.api.constants import Membership, EventTypes from synapse.types import get_domain_from_id import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) @@ -34,112 +38,47 @@ RoomsForUser = namedtuple( ("room_id", "sender", "membership", "event_id", "stream_ordering") ) +GetRoomsForUserWithStreamOrdering = namedtuple( + "_GetRoomsForUserWithStreamOrdering", + ("room_id", "stream_ordering",) +) -_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" - - -class RoomMemberStore(SQLBaseStore): - def __init__(self, hs): - super(RoomMemberStore, self).__init__(hs) - self.register_background_update_handler( - _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile - ) - def _store_room_members_txn(self, txn, events, backfilled): - """Store a room member in the database. - """ - self._simple_insert_many_txn( - txn, - table="room_memberships", - values=[ - { - "event_id": event.event_id, - "user_id": event.state_key, - "sender": event.user_id, - "room_id": event.room_id, - "membership": event.membership, - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - for event in events - ] - ) +# We store this using a namedtuple so that we save about 3x space over using a +# dict. +ProfileInfo = namedtuple( + "ProfileInfo", ("avatar_url", "display_name") +) - for event in events: - txn.call_after( - self._membership_stream_cache.entity_has_changed, - event.state_key, event.internal_metadata.stream_ordering - ) - txn.call_after( - self.get_invited_rooms_for_user.invalidate, (event.state_key,) - ) - # We update the local_invites table only if the event is "current", - # i.e., its something that has just happened. - # The only current event that can also be an outlier is if its an - # invite that has come in across federation. - is_new_state = not backfilled and ( - not event.internal_metadata.is_outlier() - or event.internal_metadata.is_invite_from_remote() - ) - is_mine = self.hs.is_mine_id(event.state_key) - if is_new_state and is_mine: - if event.membership == Membership.INVITE: - self._simple_insert_txn( - txn, - table="local_invites", - values={ - "event_id": event.event_id, - "invitee": event.state_key, - "inviter": event.sender, - "room_id": event.room_id, - "stream_id": event.internal_metadata.stream_ordering, - } - ) - else: - sql = ( - "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" - ) +_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" - txn.execute(sql, ( - event.internal_metadata.stream_ordering, - event.event_id, - event.room_id, - event.state_key, - )) - @defer.inlineCallbacks - def locally_reject_invite(self, user_id, room_id): - sql = ( - "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" - " room_id = ? AND invitee = ? AND locally_rejected is NULL" - " AND replaced_by is NULL" +class RoomMemberWorkerStore(EventsWorkerStore): + @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) + def get_hosts_in_room(self, room_id, cache_context): + """Returns the set of all hosts currently in the room + """ + user_ids = yield self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate, ) + hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) + defer.returnValue(hosts) - def f(txn, stream_ordering): - txn.execute(sql, ( - stream_ordering, - True, - room_id, - user_id, - )) - - with self._stream_id_gen.get_next() as stream_ordering: - yield self.runInteraction("locally_reject_invite", f, stream_ordering) - - @cached(max_entries=500000, iterable=True) + @cached(max_entries=100000, iterable=True) def get_users_in_room(self, room_id): def f(txn): - - rows = self._get_members_rows_txn( - txn, - room_id=room_id, - membership=Membership.JOIN, + sql = ( + "SELECT m.user_id FROM room_memberships as m" + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id " + " AND m.room_id = c.room_id " + " AND m.user_id = c.state_key" + " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" ) - return [r["user_id"] for r in rows] + txn.execute(sql, (room_id, Membership.JOIN,)) + return [to_ascii(r[0]) for r in txn] return self.runInteraction("get_users_in_room", f) @cached() @@ -246,57 +185,382 @@ class RoomMemberStore(SQLBaseStore): return results - def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): - where_clause = "c.room_id = ?" - where_values = [room_id] - - if membership: - where_clause += " AND m.membership = ?" - where_values.append(membership) + @cachedInlineCallbacks(max_entries=500000, iterable=True) + def get_rooms_for_user_with_stream_ordering(self, user_id): + """Returns a set of room_ids the user is currently joined to - if user_id: - where_clause += " AND m.user_id = ?" - where_values.append(user_id) + Args: + user_id (str) - sql = ( - "SELECT m.* FROM room_memberships as m" - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id " - " AND m.room_id = c.room_id " - " AND m.user_id = c.state_key" - " WHERE c.type = 'm.room.member' AND %(where)s" - ) % { - "where": where_clause, - } - - txn.execute(sql, where_values) - rows = self.cursor_to_dict(txn) - - return rows - - @cached(max_entries=500000, iterable=True) - def get_rooms_for_user(self, user_id): - return self.get_rooms_for_user_where_membership_is( + Returns: + Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns + the rooms the user is in currently, along with the stream ordering + of the most recent join for that user and room. + """ + rooms = yield self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN], ) + defer.returnValue(frozenset( + GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering) + for r in rooms + )) + + @defer.inlineCallbacks + def get_rooms_for_user(self, user_id, on_invalidate=None): + """Returns a set of room_ids the user is currently joined to + """ + rooms = yield self.get_rooms_for_user_with_stream_ordering( + user_id, on_invalidate=on_invalidate, + ) + defer.returnValue(frozenset(r.room_id for r in rooms)) @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) def get_users_who_share_room_with_user(self, user_id, cache_context): """Returns the set of users who share a room with `user_id` """ - rooms = yield self.get_rooms_for_user( + room_ids = yield self.get_rooms_for_user( user_id, on_invalidate=cache_context.invalidate, ) user_who_share_room = set() - for room in rooms: + for room_id in room_ids: user_ids = yield self.get_users_in_room( - room.room_id, on_invalidate=cache_context.invalidate, + room_id, on_invalidate=cache_context.invalidate, ) user_who_share_room.update(user_ids) defer.returnValue(user_who_share_room) + def get_joined_users_from_context(self, event, context): + state_group = context.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + event.room_id, state_group, context.current_state_ids, + event=event, + context=context, + ) + + def get_joined_users_from_state(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry, + ) + + @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, + max_entries=100000) + def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, + cache_context, event=None, context=None): + # We don't use `state_group`, it's there so that we can cache based + # on it. However, it's important that it's never None, since two current_states + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + users_in_room = {} + member_event_ids = [ + e_id + for key, e_id in current_state_ids.iteritems() + if key[0] == EventTypes.Member + ] + + if context is not None: + # If we have a context with a delta from a previous state group, + # check if we also have the result from the previous group in cache. + # If we do then we can reuse that result and simply update it with + # any membership changes in `delta_ids` + if context.prev_group and context.delta_ids: + prev_res = self._get_joined_users_from_context.cache.get( + (room_id, context.prev_group), None + ) + if prev_res and isinstance(prev_res, dict): + users_in_room = dict(prev_res) + member_event_ids = [ + e_id + for key, e_id in context.delta_ids.iteritems() + if key[0] == EventTypes.Member + ] + for etype, state_key in context.delta_ids: + users_in_room.pop(state_key, None) + + # We check if we have any of the member event ids in the event cache + # before we ask the DB + + # We don't update the event cache hit ratio as it completely throws off + # the hit ratio counts. After all, we don't populate the cache if we + # miss it here + event_map = self._get_events_from_cache( + member_event_ids, + allow_rejected=False, + update_metrics=False, + ) + + missing_member_event_ids = [] + for event_id in member_event_ids: + ev_entry = event_map.get(event_id) + if ev_entry: + if ev_entry.event.membership == Membership.JOIN: + users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo( + display_name=to_ascii( + ev_entry.event.content.get("displayname", None) + ), + avatar_url=to_ascii( + ev_entry.event.content.get("avatar_url", None) + ), + ) + else: + missing_member_event_ids.append(event_id) + + if missing_member_event_ids: + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=missing_member_event_ids, + retcols=('user_id', 'display_name', 'avatar_url',), + keyvalues={ + "membership": Membership.JOIN, + }, + batch_size=500, + desc="_get_joined_users_from_context", + ) + + users_in_room.update({ + to_ascii(row["user_id"]): ProfileInfo( + avatar_url=to_ascii(row["avatar_url"]), + display_name=to_ascii(row["display_name"]), + ) + for row in rows + }) + + if event is not None and event.type == EventTypes.Member: + if event.membership == Membership.JOIN: + if event.event_id in member_event_ids: + users_in_room[to_ascii(event.state_key)] = ProfileInfo( + display_name=to_ascii(event.content.get("displayname", None)), + avatar_url=to_ascii(event.content.get("avatar_url", None)), + ) + + defer.returnValue(users_in_room) + + @cachedInlineCallbacks(max_entries=10000) + def is_host_joined(self, room_id, host): + if '%' in host or '_' in host: + raise Exception("Invalid host name") + + sql = """ + SELECT state_key FROM current_state_events AS c + INNER JOIN room_memberships USING (event_id) + WHERE membership = 'join' + AND type = 'm.room.member' + AND c.room_id = ? + AND state_key LIKE ? + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause) + + if not rows: + defer.returnValue(False) + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + defer.returnValue(True) + + @cachedInlineCallbacks() + def was_host_joined(self, room_id, host): + """Check whether the server is or ever was in the room. + + Args: + room_id (str) + host (str) + + Returns: + Deferred: Resolves to True if the host is/was in the room, otherwise + False. + """ + if '%' in host or '_' in host: + raise Exception("Invalid host name") + + sql = """ + SELECT user_id FROM room_memberships + WHERE room_id = ? + AND user_id LIKE ? + AND membership = 'join' + LIMIT 1 + """ + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + like_clause = "%:" + host + + rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause) + + if not rows: + defer.returnValue(False) + + user_id = rows[0][0] + if get_domain_from_id(user_id) != host: + # This can only happen if the host name has something funky in it + raise Exception("Invalid host name") + + defer.returnValue(True) + + def get_joined_hosts(self, room_id, state_entry): + state_group = state_entry.state_group + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry, + ) + + @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) + # @defer.inlineCallbacks + def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + cache = self._get_joined_hosts_cache(room_id) + joined_hosts = yield cache.get_destinations(state_entry) + + defer.returnValue(joined_hosts) + + @cached(max_entries=10000, iterable=True) + def _get_joined_hosts_cache(self, room_id): + return _JoinedHostsCache(self, room_id) + + @cached() + def who_forgot_in_room(self, room_id): + return self._simple_select_list( + table="room_memberships", + retcols=("user_id", "event_id"), + keyvalues={ + "room_id": room_id, + "forgotten": 1, + }, + desc="who_forgot" + ) + + +class RoomMemberStore(RoomMemberWorkerStore): + def __init__(self, db_conn, hs): + super(RoomMemberStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile + ) + + def _store_room_members_txn(self, txn, events, backfilled): + """Store a room member in the database. + """ + self._simple_insert_many_txn( + txn, + table="room_memberships", + values=[ + { + "event_id": event.event_id, + "user_id": event.state_key, + "sender": event.user_id, + "room_id": event.room_id, + "membership": event.membership, + "display_name": event.content.get("displayname", None), + "avatar_url": event.content.get("avatar_url", None), + } + for event in events + ] + ) + + for event in events: + txn.call_after( + self._membership_stream_cache.entity_has_changed, + event.state_key, event.internal_metadata.stream_ordering + ) + txn.call_after( + self.get_invited_rooms_for_user.invalidate, (event.state_key,) + ) + + # We update the local_invites table only if the event is "current", + # i.e., its something that has just happened. + # The only current event that can also be an outlier is if its an + # invite that has come in across federation. + is_new_state = not backfilled and ( + not event.internal_metadata.is_outlier() + or event.internal_metadata.is_invite_from_remote() + ) + is_mine = self.hs.is_mine_id(event.state_key) + if is_new_state and is_mine: + if event.membership == Membership.INVITE: + self._simple_insert_txn( + txn, + table="local_invites", + values={ + "event_id": event.event_id, + "invitee": event.state_key, + "inviter": event.sender, + "room_id": event.room_id, + "stream_id": event.internal_metadata.stream_ordering, + } + ) + else: + sql = ( + "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + txn.execute(sql, ( + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.state_key, + )) + + @defer.inlineCallbacks + def locally_reject_invite(self, user_id, room_id): + sql = ( + "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE" + " room_id = ? AND invitee = ? AND locally_rejected is NULL" + " AND replaced_by is NULL" + ) + + def f(txn, stream_ordering): + txn.execute(sql, ( + stream_ordering, + True, + room_id, + user_id, + )) + + with self._stream_id_gen.get_next() as stream_ordering: + yield self.runInteraction("locally_reject_invite", f, stream_ordering) + def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" def f(txn): @@ -368,124 +632,6 @@ class RoomMemberStore(SQLBaseStore): forgot = yield self.runInteraction("did_forget_membership_at", f) defer.returnValue(forgot == 1) - @cached() - def who_forgot_in_room(self, room_id): - return self._simple_select_list( - table="room_memberships", - retcols=("user_id", "event_id"), - keyvalues={ - "room_id": room_id, - "forgotten": 1, - }, - desc="who_forgot" - ) - - def get_joined_users_from_context(self, event, context): - state_group = context.state_group - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - return self._get_joined_users_from_context( - event.room_id, state_group, context.current_state_ids, event=event, - ) - - def get_joined_users_from_state(self, room_id, state_group, state_ids): - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - return self._get_joined_users_from_context( - room_id, state_group, state_ids, - ) - - @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True, - max_entries=100000) - def _get_joined_users_from_context(self, room_id, state_group, current_state_ids, - cache_context, event=None): - # We don't use `state_group`, it's there so that we can cache based - # on it. However, it's important that it's never None, since two current_states - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - member_event_ids = [ - e_id - for key, e_id in current_state_ids.iteritems() - if key[0] == EventTypes.Member - ] - - rows = yield self._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=member_event_ids, - retcols=['user_id', 'display_name', 'avatar_url'], - keyvalues={ - "membership": Membership.JOIN, - }, - batch_size=500, - desc="_get_joined_users_from_context", - ) - - users_in_room = { - row["user_id"]: { - "display_name": row["display_name"], - "avatar_url": row["avatar_url"], - } - for row in rows - } - - if event is not None and event.type == EventTypes.Member: - if event.membership == Membership.JOIN: - if event.event_id in member_event_ids: - users_in_room[event.state_key] = { - "display_name": event.content.get("displayname", None), - "avatar_url": event.content.get("avatar_url", None), - } - - defer.returnValue(users_in_room) - - def is_host_joined(self, room_id, host, state_group, state_ids): - if not state_group: - # If state_group is None it means it has yet to be assigned a - # state group, i.e. we need to make sure that calls with a state_group - # of None don't hit previous cached calls with a None state_group. - # To do this we set the state_group to a new object as object() != object() - state_group = object() - - return self._is_host_joined( - room_id, host, state_group, state_ids - ) - - @cachedInlineCallbacks(num_args=3) - def _is_host_joined(self, room_id, host, state_group, current_state_ids): - # We don't use `state_group`, its there so that we can cache based - # on it. However, its important that its never None, since two current_state's - # with a state_group of None are likely to be different. - # See bulk_get_push_rules_for_room for how we work around this. - assert state_group is not None - - for (etype, state_key), event_id in current_state_ids.items(): - if etype == EventTypes.Member: - try: - if get_domain_from_id(state_key) != host: - continue - except: - logger.warn("state_key not user_id: %s", state_key) - continue - - event = yield self.get_event(event_id, allow_none=True) - if event and event.content["membership"] == Membership.JOIN: - defer.returnValue(True) - - defer.returnValue(False) - @defer.inlineCallbacks def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( @@ -499,8 +645,9 @@ class RoomMemberStore(SQLBaseStore): def add_membership_profile_txn(txn): sql = (""" - SELECT stream_ordering, event_id, events.room_id, content + SELECT stream_ordering, event_id, events.room_id, event_json.json FROM events + INNER JOIN event_json USING (event_id) INNER JOIN room_memberships USING (event_id) WHERE ? <= stream_ordering AND stream_ordering < ? AND type = 'm.room.member' @@ -521,8 +668,9 @@ class RoomMemberStore(SQLBaseStore): event_id = row["event_id"] room_id = row["room_id"] try: - content = json.loads(row["content"]) - except: + event_json = json.loads(row["json"]) + content = event_json['content'] + except Exception: continue display_name = content.get("displayname", None) @@ -560,3 +708,71 @@ class RoomMemberStore(SQLBaseStore): yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME) defer.returnValue(result) + + +class _JoinedHostsCache(object): + """Cache for joined hosts in a room that is optimised to handle updates + via state deltas. + """ + + def __init__(self, store, room_id): + self.store = store + self.room_id = room_id + + self.hosts_to_joined_users = {} + + self.state_group = object() + + self.linearizer = Linearizer("_JoinedHostsCache") + + self._len = 0 + + @defer.inlineCallbacks + def get_destinations(self, state_entry): + """Get set of destinations for a state entry + + Args: + state_entry(synapse.state._StateCacheEntry) + """ + if state_entry.state_group == self.state_group: + defer.returnValue(frozenset(self.hosts_to_joined_users)) + + with (yield self.linearizer.queue(())): + if state_entry.state_group == self.state_group: + pass + elif state_entry.prev_group == self.state_group: + for (typ, state_key), event_id in state_entry.delta_ids.iteritems(): + if typ != EventTypes.Member: + continue + + host = intern_string(get_domain_from_id(state_key)) + user_id = state_key + known_joins = self.hosts_to_joined_users.setdefault(host, set()) + + event = yield self.store.get_event(event_id) + if event.membership == Membership.JOIN: + known_joins.add(user_id) + else: + known_joins.discard(user_id) + + if not known_joins: + self.hosts_to_joined_users.pop(host, None) + else: + joined_users = yield self.store.get_joined_users_from_state( + self.room_id, state_entry, + ) + + self.hosts_to_joined_users = {} + for user_id in joined_users: + host = intern_string(get_domain_from_id(user_id)) + self.hosts_to_joined_users.setdefault(host, set()).add(user_id) + + if state_entry.state_group: + self.state_group = state_entry.state_group + else: + self.state_group = object() + self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues()) + defer.returnValue(frozenset(self.hosts_to_joined_users)) + + def __len__(self): + return self._len diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py index 8755bb2e49..4d725b92fe 100644 --- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py +++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging +import simplejson as json + logger = logging.getLogger(__name__) diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py index 4269ac69ad..e7351c3ae6 100644 --- a/synapse/storage/schema/delta/25/fts.py +++ b/synapse/storage/schema/delta/25/fts.py @@ -17,7 +17,7 @@ import logging from synapse.storage.prepare_database import get_statements from synapse.storage.engines import PostgresEngine, Sqlite3Engine -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -66,7 +66,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py index 71b12a2731..6df57b5206 100644 --- a/synapse/storage/schema/delta/27/ts.py +++ b/synapse/storage/schema/delta/27/ts.py @@ -16,7 +16,7 @@ import logging from synapse.storage.prepare_database import get_statements -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py index 5b7d8d1ab5..85bd1a2006 100644 --- a/synapse/storage/schema/delta/30/as_users.py +++ b/synapse/storage/schema/delta/30/as_users.py @@ -14,6 +14,8 @@ import logging from synapse.config.appservice import load_appservices +from six.moves import range + logger = logging.getLogger(__name__) @@ -22,7 +24,7 @@ def run_create(cur, database_engine, *args, **kwargs): # NULL indicates user was not registered by an appservice. try: cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT") - except: + except Exception: # Maybe we already added the column? Hope so... pass @@ -58,7 +60,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): for as_id, user_ids in owned.items(): n = 100 - user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n)) + user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n)) for chunk in user_chunks: cur.execute( database_engine.convert_param_style( diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py index 470ae0c005..fe6b7d196d 100644 --- a/synapse/storage/schema/delta/31/search_update.py +++ b/synapse/storage/schema/delta/31/search_update.py @@ -16,7 +16,7 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.prepare_database import get_statements import logging -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -49,7 +49,7 @@ def run_create(cur, database_engine, *args, **kwargs): "rows_inserted": 0, "have_added_indexes": False, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py index 83066cccc9..1e002f9db2 100644 --- a/synapse/storage/schema/delta/33/event_fields.py +++ b/synapse/storage/schema/delta/33/event_fields.py @@ -15,7 +15,7 @@ from synapse.storage.prepare_database import get_statements import logging -import ujson +import simplejson logger = logging.getLogger(__name__) @@ -44,7 +44,7 @@ def run_create(cur, database_engine, *args, **kwargs): "max_stream_id_exclusive": max_stream_id + 1, "rows_inserted": 0, } - progress_json = ujson.dumps(progress) + progress_json = simplejson.dumps(progress) sql = ( "INSERT into background_updates (update_name, progress_json)" diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/schema/delta/37/remove_auth_idx.py index 784f3b348f..20ad8bd5a6 100644 --- a/synapse/storage/schema/delta/37/remove_auth_idx.py +++ b/synapse/storage/schema/delta/37/remove_auth_idx.py @@ -36,6 +36,10 @@ DROP INDEX IF EXISTS transactions_have_ref; -- and is used incredibly rarely. DROP INDEX IF EXISTS events_order_topo_stream_room; +-- an equivalent index to this actually gets re-created in delta 41, because it +-- turned out that deleting it wasn't a great plan :/. In any case, let's +-- delete it here, and delta 41 will create a new one with an added UNIQUE +-- constraint DROP INDEX IF EXISTS event_search_ev_idx; """ diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/schema/delta/38/postgres_fts_gist.sql index f090a7b75a..515e6b8e84 100644 --- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql +++ b/synapse/storage/schema/delta/38/postgres_fts_gist.sql @@ -13,5 +13,7 @@ * limitations under the License. */ - INSERT into background_updates (update_name, progress_json) - VALUES ('event_search_postgres_gist', '{}'); +-- We no longer do this given we back it out again in schema 47 + +-- INSERT into background_updates (update_name, progress_json) +-- VALUES ('event_search_postgres_gist', '{}'); diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/schema/delta/40/event_push_summary.sql new file mode 100644 index 0000000000..3918f0b794 --- /dev/null +++ b/synapse/storage/schema/delta/40/event_push_summary.sql @@ -0,0 +1,37 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Aggregate of old notification counts that have been deleted out of the +-- main event_push_actions table. This count does not include those that were +-- highlights, as they remain in the event_push_actions table. +CREATE TABLE event_push_summary ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notif_count BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL +); + +CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); + + +-- The stream ordering up to which we have aggregated the event_push_actions +-- table into event_push_summary +CREATE TABLE event_push_summary_stream_ordering ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_ordering BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/schema/delta/40/pushers.sql new file mode 100644 index 0000000000..054a223f14 --- /dev/null +++ b/synapse/storage/schema/delta/40/pushers.sql @@ -0,0 +1,39 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag TEXT NOT NULL, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang TEXT, + data TEXT, + last_stream_ordering INTEGER, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) +); + +INSERT INTO pushers2 SELECT * FROM PUSHERS; + +DROP TABLE PUSHERS; + +ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/schema/delta/23/refresh_tokens.sql b/synapse/storage/schema/delta/41/device_list_stream_idx.sql index 34db0cf12b..b7bee8b692 100644 --- a/synapse/storage/schema/delta/23/refresh_tokens.sql +++ b/synapse/storage/schema/delta/41/device_list_stream_idx.sql @@ -1,4 +1,4 @@ -/* Copyright 2015, 2016 OpenMarket Ltd +/* Copyright 2017 Vector Creations Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,9 +13,5 @@ * limitations under the License. */ -CREATE TABLE IF NOT EXISTS refresh_tokens( - id INTEGER PRIMARY KEY, - token TEXT NOT NULL, - user_id TEXT NOT NULL, - UNIQUE (token) -); +INSERT into background_updates (update_name, progress_json) + VALUES ('device_lists_stream_idx', '{}'); diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/schema/delta/41/device_outbound_index.sql new file mode 100644 index 0000000000..62f0b9892b --- /dev/null +++ b/synapse/storage/schema/delta/41/device_outbound_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql new file mode 100644 index 0000000000..5d9cfecf36 --- /dev/null +++ b/synapse/storage/schema/delta/41/event_search_event_id_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_search_event_id_idx', '{}'); diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/schema/delta/41/ratelimit.sql new file mode 100644 index 0000000000..a194bf0238 --- /dev/null +++ b/synapse/storage/schema/delta/41/ratelimit.sql @@ -0,0 +1,22 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE ratelimit_override ( + user_id TEXT NOT NULL, + messages_per_second BIGINT, + burst_count BIGINT +); + +CREATE UNIQUE INDEX ratelimit_override_idx ON ratelimit_override(user_id); diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/schema/delta/42/current_state_delta.sql new file mode 100644 index 0000000000..d28851aff8 --- /dev/null +++ b/synapse/storage/schema/delta/42/current_state_delta.sql @@ -0,0 +1,26 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE current_state_delta_stream ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + event_id TEXT, -- Is null if the key was removed + prev_event_id TEXT -- Is null if the key was added +); + +CREATE INDEX current_state_delta_stream_idx ON current_state_delta_stream(stream_id); diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/schema/delta/42/device_list_last_id.sql new file mode 100644 index 0000000000..9ab8c14fa3 --- /dev/null +++ b/synapse/storage/schema/delta/42/device_list_last_id.sql @@ -0,0 +1,33 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- Table of last stream_id that we sent to destination for user_id. This is +-- used to fill out the `prev_id` fields of outbound device list updates. +CREATE TABLE device_lists_outbound_last_success ( + destination TEXT NOT NULL, + user_id TEXT NOT NULL, + stream_id BIGINT NOT NULL +); + +INSERT INTO device_lists_outbound_last_success + SELECT destination, user_id, coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE sent = (1 = 1) -- sqlite doesn't have inbuilt boolean values + GROUP BY destination, user_id; + +CREATE INDEX device_lists_outbound_last_success_idx ON device_lists_outbound_last_success( + destination, user_id, stream_id +); diff --git a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql b/synapse/storage/schema/delta/42/event_auth_state_only.sql index bb225dafbf..b8821ac759 100644 --- a/synapse/storage/schema/delta/33/refreshtoken_device_index.sql +++ b/synapse/storage/schema/delta/42/event_auth_state_only.sql @@ -1,4 +1,4 @@ -/* Copyright 2016 OpenMarket Ltd +/* Copyright 2017 Vector Creations Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,4 +14,4 @@ */ INSERT INTO background_updates (update_name, progress_json) VALUES - ('refresh_tokens_device_index', '{}'); + ('event_auth_state_only', '{}'); diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/schema/delta/42/user_dir.py new file mode 100644 index 0000000000..ea6a18196d --- /dev/null +++ b/synapse/storage/schema/delta/42/user_dir.py @@ -0,0 +1,84 @@ +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.storage.prepare_database import get_statements +from synapse.storage.engines import PostgresEngine, Sqlite3Engine + +logger = logging.getLogger(__name__) + + +BOTH_TABLES = """ +CREATE TABLE user_directory_stream_pos ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_id BIGINT, + CHECK (Lock='X') +); + +INSERT INTO user_directory_stream_pos (stream_id) VALUES (null); + +CREATE TABLE user_directory ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, -- A room_id that we know the user is joined to + display_name TEXT, + avatar_url TEXT +); + +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); + +CREATE TABLE users_in_pubic_room ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL -- A room_id that we know is public +); + +CREATE INDEX users_in_pubic_room_room_idx ON users_in_pubic_room(room_id); +CREATE UNIQUE INDEX users_in_pubic_room_user_idx ON users_in_pubic_room(user_id); +""" + + +POSTGRES_TABLE = """ +CREATE TABLE user_directory_search ( + user_id TEXT NOT NULL, + vector tsvector +); + +CREATE INDEX user_directory_search_fts_idx ON user_directory_search USING gin(vector); +CREATE UNIQUE INDEX user_directory_search_user_idx ON user_directory_search(user_id); +""" + + +SQLITE_TABLE = """ +CREATE VIRTUAL TABLE user_directory_search + USING fts4 ( user_id, value ); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + for statement in get_statements(BOTH_TABLES.splitlines()): + cur.execute(statement) + + if isinstance(database_engine, PostgresEngine): + for statement in get_statements(POSTGRES_TABLE.splitlines()): + cur.execute(statement) + elif isinstance(database_engine, Sqlite3Engine): + for statement in get_statements(SQLITE_TABLE.splitlines()): + cur.execute(statement) + else: + raise Exception("Unrecognized database engine") + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/schema/delta/43/blocked_rooms.sql new file mode 100644 index 0000000000..0e3cd143ff --- /dev/null +++ b/synapse/storage/schema/delta/43/blocked_rooms.sql @@ -0,0 +1,21 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE blocked_rooms ( + room_id TEXT NOT NULL, + user_id TEXT NOT NULL -- Admin who blocked the room +); + +CREATE UNIQUE INDEX blocked_rooms_idx ON blocked_rooms(room_id); diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/schema/delta/43/quarantine_media.sql new file mode 100644 index 0000000000..630907ec4f --- /dev/null +++ b/synapse/storage/schema/delta/43/quarantine_media.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE local_media_repository ADD COLUMN quarantined_by TEXT; +ALTER TABLE remote_media_cache ADD COLUMN quarantined_by TEXT; diff --git a/synapse/storage/schema/delta/33/refreshtoken_device.sql b/synapse/storage/schema/delta/43/url_cache.sql index 290bd6da86..45ebe020da 100644 --- a/synapse/storage/schema/delta/33/refreshtoken_device.sql +++ b/synapse/storage/schema/delta/43/url_cache.sql @@ -1,4 +1,4 @@ -/* Copyright 2016 OpenMarket Ltd +/* Copyright 2017 Vector Creations Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,4 +13,4 @@ * limitations under the License. */ -ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT; +ALTER TABLE local_media_repository ADD COLUMN url_cache TEXT; diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/schema/delta/43/user_share.sql new file mode 100644 index 0000000000..ee7062abe4 --- /dev/null +++ b/synapse/storage/schema/delta/43/user_share.sql @@ -0,0 +1,33 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Table keeping track of who shares a room with who. We only keep track +-- of this for local users, so `user_id` is local users only (but we do keep track +-- of which remote users share a room) +CREATE TABLE users_who_share_rooms ( + user_id TEXT NOT NULL, + other_user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + share_private BOOLEAN NOT NULL -- is the shared room private? i.e. they share a private room +); + + +CREATE UNIQUE INDEX users_who_share_rooms_u_idx ON users_who_share_rooms(user_id, other_user_id); +CREATE INDEX users_who_share_rooms_r_idx ON users_who_share_rooms(room_id); +CREATE INDEX users_who_share_rooms_o_idx ON users_who_share_rooms(other_user_id); + + +-- Make sure that we populate the table initially +UPDATE user_directory_stream_pos SET stream_id = NULL; diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql new file mode 100644 index 0000000000..b12f9b2ebf --- /dev/null +++ b/synapse/storage/schema/delta/44/expire_url_cache.sql @@ -0,0 +1,41 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this didn't work on SQLite 3.7 (because of lack of partial indexes), so was +-- removed and replaced with 46/local_media_repository_url_idx.sql. +-- +-- CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; + +-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support +-- indices on expressions until 3.9. +CREATE TABLE local_media_repository_url_cache_new( + url TEXT, + response_code INTEGER, + etag TEXT, + expires_ts BIGINT, + og TEXT, + media_id TEXT, + download_ts BIGINT +); + +INSERT INTO local_media_repository_url_cache_new + SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache; + +DROP TABLE local_media_repository_url_cache; +ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache; + +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/schema/delta/45/group_server.sql new file mode 100644 index 0000000000..b2333848a0 --- /dev/null +++ b/synapse/storage/schema/delta/45/group_server.sql @@ -0,0 +1,167 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE groups ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT +); + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); + + +-- list of users the group server thinks are joined +CREATE TABLE group_users ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + is_admin BOOLEAN NOT NULL, + is_public BOOLEAN NOT NULL -- whether the users membership can be seen by everyone +); + + +CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id); +CREATE INDEX groups_users_u_idx ON group_users(user_id); + +-- list of users the group server thinks are invited +CREATE TABLE group_invites ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL +); + +CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id); +CREATE INDEX groups_invites_u_idx ON group_invites(user_id); + + +CREATE TABLE group_rooms ( + group_id TEXT NOT NULL, + room_id TEXT NOT NULL, + is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone +); + +CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id); +CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id); + + +-- Rooms to include in the summary +CREATE TABLE group_summary_rooms ( + group_id TEXT NOT NULL, + room_id TEXT NOT NULL, + category_id TEXT NOT NULL, + room_order BIGINT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the room should be show to everyone + UNIQUE (group_id, category_id, room_id, room_order), + CHECK (room_order > 0) +); + +CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id); + + +-- Categories to include in the summary +CREATE TABLE group_summary_room_categories ( + group_id TEXT NOT NULL, + category_id TEXT NOT NULL, + cat_order BIGINT NOT NULL, + UNIQUE (group_id, category_id, cat_order), + CHECK (cat_order > 0) +); + +-- The categories in the group +CREATE TABLE group_room_categories ( + group_id TEXT NOT NULL, + category_id TEXT NOT NULL, + profile TEXT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the category should be show to everyone + UNIQUE (group_id, category_id) +); + +-- The users to include in the group summary +CREATE TABLE group_summary_users ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + role_id TEXT NOT NULL, + user_order BIGINT NOT NULL, + is_public BOOLEAN NOT NULL -- whether the user should be show to everyone +); + +CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id); + +-- The roles to include in the group summary +CREATE TABLE group_summary_roles ( + group_id TEXT NOT NULL, + role_id TEXT NOT NULL, + role_order BIGINT NOT NULL, + UNIQUE (group_id, role_id, role_order), + CHECK (role_order > 0) +); + + +-- The roles in a groups +CREATE TABLE group_roles ( + group_id TEXT NOT NULL, + role_id TEXT NOT NULL, + profile TEXT NOT NULL, + is_public BOOLEAN NOT NULL, -- whether the role should be show to everyone + UNIQUE (group_id, role_id) +); + + +-- List of attestations we've given out and need to renew +CREATE TABLE group_attestations_renewals ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + valid_until_ms BIGINT NOT NULL +); + +CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id); +CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id); +CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms); + + +-- List of attestations we've received from remotes and are interested in. +CREATE TABLE group_attestations_remote ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + valid_until_ms BIGINT NOT NULL, + attestation_json TEXT NOT NULL +); + +CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id); +CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id); +CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms); + + +-- The group membership for the HS's users +CREATE TABLE local_group_membership ( + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + is_admin BOOLEAN NOT NULL, + membership TEXT NOT NULL, + is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership + content TEXT NOT NULL +); + +CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id); +CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id); + + +CREATE TABLE local_group_updates ( + stream_id BIGINT NOT NULL, + group_id TEXT NOT NULL, + user_id TEXT NOT NULL, + type TEXT NOT NULL, + content TEXT NOT NULL +); diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/schema/delta/45/profile_cache.sql new file mode 100644 index 0000000000..e5ddc84df0 --- /dev/null +++ b/synapse/storage/schema/delta/45/profile_cache.sql @@ -0,0 +1,28 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A subset of remote users whose profiles we have cached. +-- Whether a user is in this table or not is defined by the storage function +-- `is_subscribed_remote_profile_for_user` +CREATE TABLE remote_profile_cache ( + user_id TEXT NOT NULL, + displayname TEXT, + avatar_url TEXT, + last_check BIGINT NOT NULL +); + +CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id); +CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check); diff --git a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql new file mode 100644 index 0000000000..68c48a89a9 --- /dev/null +++ b/synapse/storage/schema/delta/46/drop_refresh_tokens.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* we no longer use (or create) the refresh_tokens table */ +DROP TABLE IF EXISTS refresh_tokens; diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql new file mode 100644 index 0000000000..bb307889c1 --- /dev/null +++ b/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- drop the unique constraint on deleted_pushers so that we can just insert +-- into it rather than upserting. + +CREATE TABLE deleted_pushers2 ( + stream_id BIGINT NOT NULL, + app_id TEXT NOT NULL, + pushkey TEXT NOT NULL, + user_id TEXT NOT NULL +); + +INSERT INTO deleted_pushers2 (stream_id, app_id, pushkey, user_id) + SELECT stream_id, app_id, pushkey, user_id from deleted_pushers; + +DROP TABLE deleted_pushers; +ALTER TABLE deleted_pushers2 RENAME TO deleted_pushers; + +-- create the index after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX deleted_pushers_stream_id ON deleted_pushers (stream_id); + diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/schema/delta/46/group_server.sql new file mode 100644 index 0000000000..097679bc9a --- /dev/null +++ b/synapse/storage/schema/delta/46/group_server.sql @@ -0,0 +1,32 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE groups_new ( + group_id TEXT NOT NULL, + name TEXT, -- the display name of the room + avatar_url TEXT, + short_description TEXT, + long_description TEXT, + is_public BOOL NOT NULL -- whether non-members can access group APIs +); + +-- NB: awful hack to get the default to be true on postgres and 1 on sqlite +INSERT INTO groups_new + SELECT group_id, name, avatar_url, short_description, long_description, (1=1) FROM groups; + +DROP TABLE groups; +ALTER TABLE groups_new RENAME TO groups; + +CREATE UNIQUE INDEX groups_idx ON groups(group_id); diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql new file mode 100644 index 0000000000..bbfc7f5d1a --- /dev/null +++ b/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- register a background update which will recreate the +-- local_media_repository_url_idx index. +-- +-- We do this as a bg update not because it is a particularly onerous +-- operation, but because we'd like it to be a partial index if possible, and +-- the background_index_update code will understand whether we are on +-- postgres or sqlite and behave accordingly. +INSERT INTO background_updates (update_name, progress_json) VALUES + ('local_media_repository_url_idx', '{}'); diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql new file mode 100644 index 0000000000..cb0d5a2576 --- /dev/null +++ b/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql @@ -0,0 +1,35 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- change the user_directory table to also cover global local user profiles +-- rather than just profiles within specific rooms. + +CREATE TABLE user_directory2 ( + user_id TEXT NOT NULL, + room_id TEXT, + display_name TEXT, + avatar_url TEXT +); + +INSERT INTO user_directory2(user_id, room_id, display_name, avatar_url) + SELECT user_id, room_id, display_name, avatar_url from user_directory; + +DROP TABLE user_directory; +ALTER TABLE user_directory2 RENAME TO user_directory; + +-- create indexes after doing the inserts because that's more efficient. +-- it also means we can give it the same name as the old one without renaming. +CREATE INDEX user_directory_room_idx ON user_directory(room_id); +CREATE UNIQUE INDEX user_directory_user_idx ON user_directory(user_id); diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/schema/delta/46/user_dir_typos.sql new file mode 100644 index 0000000000..d9505f8da1 --- /dev/null +++ b/synapse/storage/schema/delta/46/user_dir_typos.sql @@ -0,0 +1,24 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- this is just embarassing :| +ALTER TABLE users_in_pubic_room RENAME TO users_in_public_rooms; + +-- this is only 300K rows on matrix.org and takes ~3s to generate the index, +-- so is hopefully not going to block anyone else for that long... +CREATE INDEX users_in_public_rooms_room_idx ON users_in_public_rooms(room_id); +CREATE UNIQUE INDEX users_in_public_rooms_user_idx ON users_in_public_rooms(user_id); +DROP INDEX users_in_pubic_room_room_idx; +DROP INDEX users_in_pubic_room_user_idx; diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/schema/delta/47/last_access_media.sql new file mode 100644 index 0000000000..f505fb22b5 --- /dev/null +++ b/synapse/storage/schema/delta/47/last_access_media.sql @@ -0,0 +1,16 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE local_media_repository ADD COLUMN last_access_ts BIGINT; diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/schema/delta/47/postgres_fts_gin.sql new file mode 100644 index 0000000000..31d7a817eb --- /dev/null +++ b/synapse/storage/schema/delta/47/postgres_fts_gin.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('event_search_postgres_gin', '{}'); diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/schema/delta/47/push_actions_staging.sql new file mode 100644 index 0000000000..edccf4a96f --- /dev/null +++ b/synapse/storage/schema/delta/47/push_actions_staging.sql @@ -0,0 +1,28 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Temporary staging area for push actions that have been calculated for an +-- event, but the event hasn't yet been persisted. +-- When the event is persisted the rows are moved over to the +-- event_push_actions table. +CREATE TABLE event_push_actions_staging ( + event_id TEXT NOT NULL, + user_id TEXT NOT NULL, + actions TEXT NOT NULL, + notif SMALLINT NOT NULL, + highlight SMALLINT NOT NULL +); + +CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging(event_id); diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py new file mode 100644 index 0000000000..f6766501d2 --- /dev/null +++ b/synapse/storage/schema/delta/47/state_group_seq.py @@ -0,0 +1,37 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine + + +def run_create(cur, database_engine, *args, **kwargs): + if isinstance(database_engine, PostgresEngine): + # if we already have some state groups, we want to start making new + # ones with a higher id. + cur.execute("SELECT max(id) FROM state_groups") + row = cur.fetchone() + + if row[0] is None: + start_val = 1 + else: + start_val = row[0] + 1 + + cur.execute( + "CREATE SEQUENCE state_group_id_seq START WITH %s", + (start_val, ), + ) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql new file mode 100644 index 0000000000..9248b0b24a --- /dev/null +++ b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql @@ -0,0 +1,17 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('user_ips_last_seen_index', '{}'); diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py new file mode 100644 index 0000000000..2233af87d7 --- /dev/null +++ b/synapse/storage/schema/delta/48/group_unique_indexes.py @@ -0,0 +1,57 @@ +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.storage.engines import PostgresEngine +from synapse.storage.prepare_database import get_statements + +FIX_INDEXES = """ +-- rebuild indexes as uniques +DROP INDEX groups_invites_g_idx; +CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id); +DROP INDEX groups_users_g_idx; +CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id); + +-- rename other indexes to actually match their table names.. +DROP INDEX groups_users_u_idx; +CREATE INDEX group_users_u_idx ON group_users(user_id); +DROP INDEX groups_invites_u_idx; +CREATE INDEX group_invites_u_idx ON group_invites(user_id); +DROP INDEX groups_rooms_g_idx; +CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id); +DROP INDEX groups_rooms_r_idx; +CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); +""" + + +def run_create(cur, database_engine, *args, **kwargs): + rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" + + # remove duplicates from group_users & group_invites tables + cur.execute(""" + DELETE FROM group_users WHERE %s NOT IN ( + SELECT min(%s) FROM group_users GROUP BY group_id, user_id + ); + """ % (rowid, rowid)) + cur.execute(""" + DELETE FROM group_invites WHERE %s NOT IN ( + SELECT min(%s) FROM group_invites GROUP BY group_id, user_id + ); + """ % (rowid, rowid)) + + for statement in get_statements(FIX_INDEXES.splitlines()): + cur.execute(statement) + + +def run_upgrade(*args, **kwargs): + pass diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/schema/delta/48/groups_joinable.sql new file mode 100644 index 0000000000..ce26eaf0c9 --- /dev/null +++ b/synapse/storage/schema/delta/48/groups_joinable.sql @@ -0,0 +1,22 @@ +/* Copyright 2018 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * This isn't a real ENUM because sqlite doesn't support it + * and we use a default of NULL for inserted rows and interpret + * NULL at the python store level as necessary so that existing + * rows are given the correct default policy. + */ +ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite'; diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql index a7ade69986..42e5cb6df5 100644 --- a/synapse/storage/schema/schema_version.sql +++ b/synapse/storage/schema/schema_version.sql @@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas( file TEXT NOT NULL, UNIQUE(version, file) ); + +-- a list of schema files we have loaded on behalf of dynamic modules +CREATE TABLE IF NOT EXISTS applied_module_schemas( + module_name TEXT NOT NULL, + file TEXT NOT NULL, + UNIQUE(module_name, file) +); diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 8f2b3c4435..6ba3e59889 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -13,28 +13,35 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import namedtuple +import logging +import re +import simplejson as json + from twisted.internet import defer from .background_updates import BackgroundUpdateStore from synapse.api.errors import SynapseError from synapse.storage.engines import PostgresEngine, Sqlite3Engine -import logging -import re -import ujson as json - logger = logging.getLogger(__name__) +SearchEntry = namedtuple('SearchEntry', [ + 'key', 'value', 'event_id', 'room_id', 'stream_ordering', + 'origin_server_ts', +]) + class SearchStore(BackgroundUpdateStore): EVENT_SEARCH_UPDATE_NAME = "event_search" EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order" EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" + EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" - def __init__(self, hs): - super(SearchStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(SearchStore, self).__init__(db_conn, hs) self.register_background_update_handler( self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search ) @@ -42,23 +49,35 @@ class SearchStore(BackgroundUpdateStore): self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order ) - self.register_background_update_handler( + + # we used to have a background update to turn the GIN index into a + # GIST one; we no longer do that (obviously) because we actually want + # a GIN index. However, it's possible that some people might still have + # the background update queued, so we register a handler to clear the + # background update. + self.register_noop_background_update( self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME, - self._background_reindex_gist_search + ) + + self.register_background_update_handler( + self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, + self._background_reindex_gin_search ) @defer.inlineCallbacks def _background_reindex_search(self, progress, batch_size): + # we work through the events table from highest stream id to lowest target_min_stream_id = progress["target_min_stream_id_inclusive"] max_stream_id = progress["max_stream_id_exclusive"] rows_inserted = progress.get("rows_inserted", 0) - INSERT_CLUMP_SIZE = 1000 TYPES = ["m.room.name", "m.room.message", "m.room.topic"] def reindex_search_txn(txn): sql = ( - "SELECT stream_ordering, event_id, room_id, type, content FROM events" + "SELECT stream_ordering, event_id, room_id, type, json, " + " origin_server_ts FROM events" + " JOIN event_json USING (room_id, event_id)" " WHERE ? <= stream_ordering AND stream_ordering < ?" " AND (%s)" " ORDER BY stream_ordering DESC" @@ -67,6 +86,10 @@ class SearchStore(BackgroundUpdateStore): txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) + # we could stream straight from the results into + # store_search_entries_txn with a generator function, but that + # would mean having two cursors open on the database at once. + # Instead we just build a list of results. rows = self.cursor_to_dict(txn) if not rows: return 0 @@ -79,9 +102,12 @@ class SearchStore(BackgroundUpdateStore): event_id = row["event_id"] room_id = row["room_id"] etype = row["type"] + stream_ordering = row["stream_ordering"] + origin_server_ts = row["origin_server_ts"] try: - content = json.loads(row["content"]) - except: + event_json = json.loads(row["json"]) + content = event_json["content"] + except Exception: continue if etype == "m.room.message": @@ -93,6 +119,8 @@ class SearchStore(BackgroundUpdateStore): elif etype == "m.room.name": key = "content.name" value = content["name"] + else: + raise Exception("unexpected event type %s" % etype) except (KeyError, AttributeError): # If the event is missing a necessary field then # skip over it. @@ -103,25 +131,16 @@ class SearchStore(BackgroundUpdateStore): # then skip over it continue - event_search_rows.append((event_id, room_id, key, value)) + event_search_rows.append(SearchEntry( + key=key, + value=value, + event_id=event_id, + room_id=room_id, + stream_ordering=stream_ordering, + origin_server_ts=origin_server_ts, + )) - if isinstance(self.database_engine, PostgresEngine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, vector)" - " VALUES (?,?,?,to_tsvector('english', ?))" - ) - elif isinstance(self.database_engine, Sqlite3Engine): - sql = ( - "INSERT INTO event_search (event_id, room_id, key, value)" - " VALUES (?,?,?,?)" - ) - else: - # This should be unreachable. - raise Exception("Unrecognized database engine") - - for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE): - clump = event_search_rows[index:index + INSERT_CLUMP_SIZE] - txn.executemany(sql, clump) + self.store_search_entries_txn(txn, event_search_rows) progress = { "target_min_stream_id_inclusive": target_min_stream_id, @@ -145,25 +164,48 @@ class SearchStore(BackgroundUpdateStore): defer.returnValue(result) @defer.inlineCallbacks - def _background_reindex_gist_search(self, progress, batch_size): + def _background_reindex_gin_search(self, progress, batch_size): + """This handles old synapses which used GIST indexes, if any; + converting them back to be GIN as per the actual schema. + """ + def create_index(conn): conn.rollback() - conn.set_session(autocommit=True) - c = conn.cursor() - c.execute( - "CREATE INDEX CONCURRENTLY event_search_fts_idx_gist" - " ON event_search USING GIST (vector)" - ) + # we have to set autocommit, because postgres refuses to + # CREATE INDEX CONCURRENTLY without it. + conn.set_session(autocommit=True) - c.execute("DROP INDEX event_search_fts_idx") + try: + c = conn.cursor() - conn.set_session(autocommit=False) + # if we skipped the conversion to GIST, we may already/still + # have an event_search_fts_idx; unfortunately postgres 9.4 + # doesn't support CREATE INDEX IF EXISTS so we just catch the + # exception and ignore it. + import psycopg2 + try: + c.execute( + "CREATE INDEX CONCURRENTLY event_search_fts_idx" + " ON event_search USING GIN (vector)" + ) + except psycopg2.ProgrammingError as e: + logger.warn( + "Ignoring error %r when trying to switch from GIST to GIN", + e + ) + + # we should now be able to delete the GIST index. + c.execute( + "DROP INDEX IF EXISTS event_search_fts_idx_gist" + ) + finally: + conn.set_session(autocommit=False) if isinstance(self.database_engine, PostgresEngine): yield self.runWithConnection(create_index) - yield self._end_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME) + yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME) defer.returnValue(1) @defer.inlineCallbacks @@ -242,6 +284,85 @@ class SearchStore(BackgroundUpdateStore): defer.returnValue(num_rows) + def store_event_search_txn(self, txn, event, key, value): + """Add event to the search table + + Args: + txn (cursor): + event (EventBase): + key (str): + value (str): + """ + self.store_search_entries_txn( + txn, + (SearchEntry( + key=key, + value=value, + event_id=event.event_id, + room_id=event.room_id, + stream_ordering=event.internal_metadata.stream_ordering, + origin_server_ts=event.origin_server_ts, + ),), + ) + + def store_search_entries_txn(self, txn, entries): + """Add entries to the search table + + Args: + txn (cursor): + entries (iterable[SearchEntry]): + entries to be added to the table + """ + if isinstance(self.database_engine, PostgresEngine): + sql = ( + "INSERT INTO event_search" + " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)" + " VALUES (?,?,?,to_tsvector('english', ?),?,?)" + ) + + args = (( + entry.event_id, entry.room_id, entry.key, entry.value, + entry.stream_ordering, entry.origin_server_ts, + ) for entry in entries) + + # inserts to a GIN index are normally batched up into a pending + # list, and then all committed together once the list gets to a + # certain size. The trouble with that is that postgres (pre-9.5) + # uses work_mem to determine the length of the list, and work_mem + # is typically very large. + # + # We therefore reduce work_mem while we do the insert. + # + # (postgres 9.5 uses the separate gin_pending_list_limit setting, + # so doesn't suffer the same problem, but changing work_mem will + # be harmless) + # + # Note that we don't need to worry about restoring it on + # exception, because exceptions will cause the transaction to be + # rolled back, including the effects of the SET command. + # + # Also: we use SET rather than SET LOCAL because there's lots of + # other stuff going on in this transaction, which want to have the + # normal work_mem setting. + + txn.execute("SET work_mem='256kB'") + txn.executemany(sql, args) + txn.execute("RESET work_mem") + + elif isinstance(self.database_engine, Sqlite3Engine): + sql = ( + "INSERT INTO event_search (event_id, room_id, key, value)" + " VALUES (?,?,?,?)" + ) + args = (( + entry.event_id, entry.room_id, entry.key, entry.value, + ) for entry in entries) + + txn.executemany(sql, args) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + @defer.inlineCallbacks def search_msgs(self, room_ids, search_term, keys): """Performs a full text search over events with given keys. @@ -407,7 +528,7 @@ class SearchStore(BackgroundUpdateStore): origin_server_ts, stream = pagination_token.split(",") origin_server_ts = int(origin_server_ts) stream = int(stream) - except: + except Exception: raise SynapseError(400, "Invalid pagination token") clauses.append( diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index e1dca927d7..9e6eaaa532 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -22,12 +22,12 @@ from synapse.crypto.event_signing import compute_event_reference_hash from synapse.util.caches.descriptors import cached, cachedList -class SignatureStore(SQLBaseStore): - """Persistence for event signatures and hashes""" - +class SignatureWorkerStore(SQLBaseStore): @cached() def get_event_reference_hash(self, event_id): - return self._get_event_reference_hashes_txn(event_id) + # This is a dummy function to allow get_event_reference_hashes + # to use its cache + raise NotImplementedError() @cachedList(cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1) @@ -72,7 +72,11 @@ class SignatureStore(SQLBaseStore): " WHERE event_id = ?" ) txn.execute(query, (event_id, )) - return {k: v for k, v in txn.fetchall()} + return {k: v for k, v in txn} + + +class SignatureStore(SignatureWorkerStore): + """Persistence for event signatures and hashes""" def _store_event_reference_hashes_txn(self, txn, events): """Store a hash for a PDU diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 84482d8285..ffa4246031 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -13,14 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList -from synapse.util.caches import intern_string -from synapse.storage.engines import PostgresEngine +from collections import namedtuple +import logging from twisted.internet import defer -import logging +from synapse.storage.background_updates import BackgroundUpdateStore +from synapse.storage.engines import PostgresEngine +from synapse.util.caches import intern_string, CACHE_SIZE_FACTOR +from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.dictionary_cache import DictionaryCache +from synapse.util.stringutils import to_ascii +from ._base import SQLBaseStore logger = logging.getLogger(__name__) @@ -28,45 +32,97 @@ logger = logging.getLogger(__name__) MAX_STATE_DELTA_HOPS = 100 -class StateStore(SQLBaseStore): - """ Keeps track of the state at a given event. +class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))): + """Return type of get_state_group_delta that implements __len__, which lets + us use the itrable flag when caching + """ + __slots__ = [] - This is done by the concept of `state groups`. Every event is a assigned - a state group (identified by an arbitrary string), which references a - collection of state events. The current state of an event is then the - collection of state events referenced by the event's state group. + def __len__(self): + return len(self.delta_ids) if self.delta_ids else 0 - Hence, every change in the current state causes a new state group to be - generated. However, if no change happens (e.g., if we get a message event - with only one parent it inherits the state group from its parent.) - There are three tables: - * `state_groups`: Stores group name, first event with in the group and - room id. - * `event_to_state_groups`: Maps events to state groups. - * `state_groups_state`: Maps state group to state events. +class StateGroupWorkerStore(SQLBaseStore): + """The parts of StateGroupStore that can be called from workers. """ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" - def __init__(self, hs): - super(StateStore, self).__init__(hs) - self.register_background_update_handler( - self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, - self._background_deduplicate_state, + def __init__(self, db_conn, hs): + super(StateGroupWorkerStore, self).__init__(db_conn, hs) + + self._state_group_cache = DictionaryCache( + "*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR ) - self.register_background_update_handler( - self.STATE_GROUP_INDEX_UPDATE_NAME, - self._background_index_state, + + @cached(max_entries=100000, iterable=True) + def get_current_state_ids(self, room_id): + """Get the current state event ids for a room based on the + current_state_events table. + + Args: + room_id (str) + + Returns: + deferred: dict of (type, state_key) -> event_id + """ + def _get_current_state_ids_txn(txn): + txn.execute( + """SELECT type, state_key, event_id FROM current_state_events + WHERE room_id = ? + """, + (room_id,) + ) + + return { + (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn + } + + return self.runInteraction( + "get_current_state_ids", + _get_current_state_ids_txn, ) - self.register_background_index_update( - self.CURRENT_STATE_INDEX_UPDATE_NAME, - index_name="current_state_events_member_index", - table="current_state_events", - columns=["state_key"], - where_clause="type='m.room.member'", + + @cached(max_entries=10000, iterable=True) + def get_state_group_delta(self, state_group): + """Given a state group try to return a previous group and a delta between + the old and the new. + + Returns: + (prev_group, delta_ids), where both may be None. + """ + def _get_state_group_delta_txn(txn): + prev_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={ + "state_group": state_group, + }, + retcol="prev_state_group", + allow_none=True, + ) + + if not prev_group: + return _GetStateGroupDelta(None, None) + + delta_ids = self._simple_select_list_txn( + txn, + table="state_groups_state", + keyvalues={ + "state_group": state_group, + }, + retcols=("type", "state_key", "event_id",) + ) + + return _GetStateGroupDelta(prev_group, { + (row["type"], row["state_key"]): row["event_id"] + for row in delta_ids + }) + return self.runInteraction( + "get_state_group_delta", + _get_state_group_delta_txn, ) @defer.inlineCallbacks @@ -78,12 +134,26 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.values()) + groups = set(event_to_groups.itervalues()) group_to_state = yield self._get_state_for_groups(groups) defer.returnValue(group_to_state) @defer.inlineCallbacks + def get_state_ids_for_group(self, state_group): + """Get the state IDs for the given state group + + Args: + state_group (int) + + Returns: + Deferred[dict]: Resolves to a map of (type, state_key) -> event_id + """ + group_to_state = yield self._get_state_for_groups((state_group,)) + + defer.returnValue(group_to_state[state_group]) + + @defer.inlineCallbacks def get_state_groups(self, room_id, event_ids): """ Get the state groups for the given list of event_ids @@ -96,156 +166,21 @@ class StateStore(SQLBaseStore): state_event_map = yield self.get_events( [ - ev_id for group_ids in group_to_ids.values() - for ev_id in group_ids.values() + ev_id for group_ids in group_to_ids.itervalues() + for ev_id in group_ids.itervalues() ], get_prev_content=False ) defer.returnValue({ group: [ - state_event_map[v] for v in event_id_map.values() if v in state_event_map + state_event_map[v] for v in event_id_map.itervalues() + if v in state_event_map ] - for group, event_id_map in group_to_ids.items() + for group, event_id_map in group_to_ids.iteritems() }) - def _have_persisted_state_group_txn(self, txn, state_group): - txn.execute( - "SELECT count(*) FROM state_groups WHERE id = ?", - (state_group,) - ) - row = txn.fetchone() - return row and row[0] - - def _store_mult_state_groups_txn(self, txn, events_and_contexts): - state_groups = {} - for event, context in events_and_contexts: - if event.internal_metadata.is_outlier(): - continue - - if context.current_state_ids is None: - continue - - state_groups[event.event_id] = context.state_group - - if self._have_persisted_state_group_txn(txn, context.state_group): - continue - - self._simple_insert_txn( - txn, - table="state_groups", - values={ - "id": context.state_group, - "room_id": event.room_id, - "event_id": event.event_id, - }, - ) - - # We persist as a delta if we can, while also ensuring the chain - # of deltas isn't tooo long, as otherwise read performance degrades. - if context.prev_group: - potential_hops = self._count_state_group_hops_txn( - txn, context.prev_group - ) - if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS: - self._simple_insert_txn( - txn, - table="state_group_edges", - values={ - "state_group": context.state_group, - "prev_state_group": context.prev_group, - }, - ) - - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.delta_ids.items() - ], - ) - else: - self._simple_insert_many_txn( - txn, - table="state_groups_state", - values=[ - { - "state_group": context.state_group, - "room_id": event.room_id, - "type": key[0], - "state_key": key[1], - "event_id": state_id, - } - for key, state_id in context.current_state_ids.items() - ], - ) - - self._simple_insert_many_txn( - txn, - table="event_to_state_groups", - values=[ - { - "state_group": state_group_id, - "event_id": event_id, - } - for event_id, state_group_id in state_groups.items() - ], - ) - - def _count_state_group_hops_txn(self, txn, state_group): - """Given a state group, count how many hops there are in the tree. - - This is used to ensure the delta chains don't get too long. - """ - if isinstance(self.database_engine, PostgresEngine): - sql = (""" - WITH RECURSIVE state(state_group) AS ( - VALUES(?::bigint) - UNION ALL - SELECT prev_state_group FROM state_group_edges e, state s - WHERE s.state_group = e.state_group - ) - SELECT count(*) FROM state; - """) - - txn.execute(sql, (state_group,)) - row = txn.fetchone() - if row and row[0]: - return row[0] - else: - return 0 - else: - # We don't use WITH RECURSIVE on sqlite3 as there are distributions - # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) - next_group = state_group - count = 0 - - while next_group: - next_group = self._simple_select_one_onecol_txn( - txn, - table="state_group_edges", - keyvalues={"state_group": next_group}, - retcol="prev_state_group", - allow_none=True, - ) - if next_group: - count += 1 - - return count - - @cached(num_args=2, max_entries=100000, iterable=True) - def _get_state_group_from_group(self, group, types): - raise NotImplementedError() - - @cachedList(cached_method_name="_get_state_group_from_group", - list_name="groups", num_args=2, inlineCallbacks=True) + @defer.inlineCallbacks def _get_state_groups_from_groups(self, groups, types): """Returns dictionary state_group -> (dict of (type, state_key) -> event id) """ @@ -305,6 +240,9 @@ class StateStore(SQLBaseStore): ( "AND type = ? AND state_key = ?", (etype, state_key) + ) if state_key is not None else ( + "AND type = ?", + (etype,) ) for etype, state_key in types ] @@ -319,15 +257,24 @@ class StateStore(SQLBaseStore): args.extend(where_args) txn.execute(sql % (where_clause,), args) - rows = self.cursor_to_dict(txn) - for row in rows: - key = (row["type"], row["state_key"]) - results[group][key] = row["event_id"] + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id else: + where_args = [] + where_clauses = [] + wildcard_types = False if types is not None: - where_clause = "AND (%s)" % ( - " OR ".join(["(type = ? AND state_key = ?)"] * len(types)), - ) + for typ in types: + if typ[1] is None: + where_clauses.append("(type = ?)") + where_args.extend(typ[0]) + wildcard_types = True + else: + where_clauses.append("(type = ? AND state_key = ?)") + where_args.extend([typ[0], typ[1]]) + where_clause = "AND (%s)" % (" OR ".join(where_clauses)) else: where_clause = "" @@ -344,23 +291,30 @@ class StateStore(SQLBaseStore): # after we finish deduping state, which requires this func) args = [next_group] if types: - args.extend(i for typ in types for i in typ) + args.extend(where_args) txn.execute( "SELECT type, state_key, event_id FROM state_groups_state" " WHERE state_group = ? %s" % (where_clause,), args ) - rows = txn.fetchall() - results[group].update({ - (typ, state_key): event_id - for typ, state_key, event_id in rows + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn if (typ, state_key) not in results[group] - }) + ) - # If the lengths match then we must have all the types, - # so no need to go walk further down the tree. - if types is not None and len(results[group]) == len(types): + # If the number of entries in the (type,state_key)->event_id dict + # matches the number of (type,state_keys) types we were searching + # for, then we must have found them all, so no need to go walk + # further down the tree... UNLESS our types filter contained + # wildcards (i.e. Nones) in which case we have to do an exhaustive + # search + if ( + types is not None and + not wildcard_types and + len(results[group]) == len(types) + ): break next_group = self._simple_select_one_onecol_txn( @@ -393,21 +347,21 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.values()) + groups = set(event_to_groups.itervalues()) group_to_state = yield self._get_state_for_groups(groups, types) state_event_map = yield self.get_events( - [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()], get_prev_content=False ) event_to_state = { event_id: { k: state_event_map[v] - for k, v in group_to_state[group].items() + for k, v in group_to_state[group].iteritems() if v in state_event_map } - for event_id, group in event_to_groups.items() + for event_id, group in event_to_groups.iteritems() } defer.returnValue({event: event_to_state[event] for event in event_ids}) @@ -430,12 +384,12 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.values()) + groups = set(event_to_groups.itervalues()) group_to_state = yield self._get_state_for_groups(groups, types) event_to_state = { event_id: group_to_state[group] - for event_id, group in event_to_groups.items() + for event_id, group in event_to_groups.iteritems() } defer.returnValue({event: event_to_state[event] for event in event_ids}) @@ -474,8 +428,8 @@ class StateStore(SQLBaseStore): state_map = yield self.get_state_ids_for_events([event_id], types) defer.returnValue(state_map[event_id]) - @cached(num_args=2, max_entries=10000) - def _get_state_group_for_event(self, room_id, event_id): + @cached(max_entries=50000) + def _get_state_group_for_event(self, event_id): return self._simple_select_one_onecol( table="event_to_state_groups", keyvalues={ @@ -517,20 +471,22 @@ class StateStore(SQLBaseStore): where a `state_key` of `None` matches all state_keys for the `type`. """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) type_to_key = {} missing_types = set() + for typ, state_key in types: + key = (typ, state_key) if state_key is None: type_to_key[typ] = None - missing_types.add((typ, state_key)) + missing_types.add(key) else: if type_to_key.get(typ, object()) is not None: type_to_key.setdefault(typ, set()).add(state_key) - if (typ, state_key) not in state_dict_ids: - missing_types.add((typ, state_key)) + if key not in state_dict_ids and key not in known_absent: + missing_types.add(key) sentinel = object() @@ -544,10 +500,10 @@ class StateStore(SQLBaseStore): return True return False - got_all = not (missing_types or types is None) + got_all = is_all or not missing_types return { - k: v for k, v in state_dict_ids.items() + k: v for k, v in state_dict_ids.iteritems() if include(k[0], k[1]) }, missing_types, got_all @@ -561,7 +517,7 @@ class StateStore(SQLBaseStore): Args: group: The state group to lookup """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, _, state_dict_ids = self._state_group_cache.get(group) return state_dict_ids, is_all @@ -578,7 +534,7 @@ class StateStore(SQLBaseStore): missing_groups = [] if types is not None: for group in set(groups): - state_dict_ids, missing_types, got_all = self._get_some_state_from_cache( + state_dict_ids, _, got_all = self._get_some_state_from_cache( group, types ) results[group] = state_dict_ids @@ -606,46 +562,247 @@ class StateStore(SQLBaseStore): # Now we want to update the cache with all the things we fetched # from the database. - for group, group_state_dict in group_to_state_dict.items(): - if types: - # We delibrately put key -> None mappings into the cache to - # cache absence of the key, on the assumption that if we've - # explicitly asked for some types then we will probably ask - # for them again. - state_dict = { - (intern_string(etype), intern_string(state_key)): None - for (etype, state_key) in types - } - state_dict.update(results[group]) - results[group] = state_dict - else: - state_dict = results[group] - - state_dict.update({ - (intern_string(k[0]), intern_string(k[1])): v - for k, v in group_state_dict.items() - }) + for group, group_state_dict in group_to_state_dict.iteritems(): + state_dict = results[group] + + state_dict.update( + ((intern_string(k[0]), intern_string(k[1])), to_ascii(v)) + for k, v in group_state_dict.iteritems() + ) self._state_group_cache.update( cache_seq_num, key=group, value=state_dict, full=(types is None), + known_absent=types, ) - # Remove all the entries with None values. The None values were just - # used for bookkeeping in the cache. - for group, state_dict in results.items(): - results[group] = { - key: event_id - for key, event_id in state_dict.items() - if event_id - } - defer.returnValue(results) - def get_next_state_group(self): - return self._state_groups_id_gen.get_next() + def store_state_group(self, event_id, room_id, prev_group, delta_ids, + current_state_ids): + """Store a new set of state, returning a newly assigned state group. + + Args: + event_id (str): The event ID for which the state was calculated + room_id (str) + prev_group (int|None): A previous state group for the room, optional. + delta_ids (dict|None): The delta between state at `prev_group` and + `current_state_ids`, if `prev_group` was given. Same format as + `current_state_ids`. + current_state_ids (dict): The state to store. Map of (type, state_key) + to event_id. + + Returns: + Deferred[int]: The state group ID + """ + def _store_state_group_txn(txn): + if current_state_ids is None: + # AFAIK, this can never happen + raise Exception("current_state_ids cannot be None") + + state_group = self.database_engine.get_next_state_group_id(txn) + + self._simple_insert_txn( + txn, + table="state_groups", + values={ + "id": state_group, + "room_id": room_id, + "event_id": event_id, + }, + ) + + # We persist as a delta if we can, while also ensuring the chain + # of deltas isn't tooo long, as otherwise read performance degrades. + if prev_group: + is_in_db = self._simple_select_one_onecol_txn( + txn, + table="state_groups", + keyvalues={"id": prev_group}, + retcol="id", + allow_none=True, + ) + if not is_in_db: + raise Exception( + "Trying to persist state with unpersisted prev_group: %r" + % (prev_group,) + ) + + potential_hops = self._count_state_group_hops_txn( + txn, prev_group + ) + if prev_group and potential_hops < MAX_STATE_DELTA_HOPS: + self._simple_insert_txn( + txn, + table="state_group_edges", + values={ + "state_group": state_group, + "prev_state_group": prev_group, + }, + ) + + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in delta_ids.iteritems() + ], + ) + else: + self._simple_insert_many_txn( + txn, + table="state_groups_state", + values=[ + { + "state_group": state_group, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + "event_id": state_id, + } + for key, state_id in current_state_ids.iteritems() + ], + ) + + # Prefill the state group cache with this group. + # It's fine to use the sequence like this as the state group map + # is immutable. (If the map wasn't immutable then this prefill could + # race with another update) + txn.call_after( + self._state_group_cache.update, + self._state_group_cache.sequence, + key=state_group, + value=dict(current_state_ids), + full=True, + ) + + return state_group + + return self.runInteraction("store_state_group", _store_state_group_txn) + + def _count_state_group_hops_txn(self, txn, state_group): + """Given a state group, count how many hops there are in the tree. + + This is used to ensure the delta chains don't get too long. + """ + if isinstance(self.database_engine, PostgresEngine): + sql = (""" + WITH RECURSIVE state(state_group) AS ( + VALUES(?::bigint) + UNION ALL + SELECT prev_state_group FROM state_group_edges e, state s + WHERE s.state_group = e.state_group + ) + SELECT count(*) FROM state; + """) + + txn.execute(sql, (state_group,)) + row = txn.fetchone() + if row and row[0]: + return row[0] + else: + return 0 + else: + # We don't use WITH RECURSIVE on sqlite3 as there are distributions + # that ship with an sqlite3 version that doesn't support it (e.g. wheezy) + next_group = state_group + count = 0 + + while next_group: + next_group = self._simple_select_one_onecol_txn( + txn, + table="state_group_edges", + keyvalues={"state_group": next_group}, + retcol="prev_state_group", + allow_none=True, + ) + if next_group: + count += 1 + + return count + + +class StateStore(StateGroupWorkerStore, BackgroundUpdateStore): + """ Keeps track of the state at a given event. + + This is done by the concept of `state groups`. Every event is a assigned + a state group (identified by an arbitrary string), which references a + collection of state events. The current state of an event is then the + collection of state events referenced by the event's state group. + + Hence, every change in the current state causes a new state group to be + generated. However, if no change happens (e.g., if we get a message event + with only one parent it inherits the state group from its parent.) + + There are three tables: + * `state_groups`: Stores group name, first event with in the group and + room id. + * `event_to_state_groups`: Maps events to state groups. + * `state_groups_state`: Maps state group to state events. + """ + + STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" + STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" + + def __init__(self, db_conn, hs): + super(StateStore, self).__init__(db_conn, hs) + self.register_background_update_handler( + self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, + self._background_deduplicate_state, + ) + self.register_background_update_handler( + self.STATE_GROUP_INDEX_UPDATE_NAME, + self._background_index_state, + ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + + def _store_event_state_mappings_txn(self, txn, events_and_contexts): + state_groups = {} + for event, context in events_and_contexts: + if event.internal_metadata.is_outlier(): + continue + + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.prev_group + continue + + state_groups[event.event_id] = context.state_group + + self._simple_insert_many_txn( + txn, + table="event_to_state_groups", + values=[ + { + "state_group": state_group_id, + "event_id": event_id, + } + for event_id, state_group_id in state_groups.iteritems() + ], + ) + + for event_id, state_group_id in state_groups.iteritems(): + txn.call_after( + self._get_state_group_for_event.prefill, + (event_id,), state_group_id + ) @defer.inlineCallbacks def _background_deduplicate_state(self, progress, batch_size): @@ -727,7 +884,7 @@ class StateStore(SQLBaseStore): # of keys delta_state = { - key: value for key, value in curr_state.items() + key: value for key, value in curr_state.iteritems() if prev_state.get(key, None) != value } @@ -767,7 +924,7 @@ class StateStore(SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in delta_state.items() + for key, state_id in delta_state.iteritems() ], ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 200d124632..f0784ba137 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -35,15 +35,20 @@ what sort order was used: from twisted.internet import defer -from ._base import SQLBaseStore +from synapse.storage._base import SQLBaseStore +from synapse.storage.events import EventsWorkerStore + from synapse.util.caches.descriptors import cached -from synapse.api.constants import EventTypes from synapse.types import RoomStreamToken -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.logcontext import make_deferred_yieldable, run_in_background from synapse.storage.engines import PostgresEngine, Sqlite3Engine +import abc import logging +from six.moves import range + logger = logging.getLogger(__name__) @@ -143,81 +148,41 @@ def filter_to_clause(event_filter): return " AND ".join(clauses), args -class StreamStore(SQLBaseStore): - @defer.inlineCallbacks - def get_appservice_room_stream(self, service, from_key, to_key, limit=0): - # NB this lives here instead of appservice.py so we can reuse the - # 'private' StreamToken class in this file. - if limit: - limit = max(limit, MAX_STREAM_SIZE) - else: - limit = MAX_STREAM_SIZE - - # From and to keys should be integers from ordering. - from_id = RoomStreamToken.parse_stream_token(from_key) - to_id = RoomStreamToken.parse_stream_token(to_key) - - if from_key == to_key: - defer.returnValue(([], to_key)) - return - - # select all the events between from/to with a sensible limit - sql = ( - "SELECT e.event_id, e.room_id, e.type, s.state_key, " - "e.stream_ordering FROM events AS e " - "LEFT JOIN state_events as s ON " - "e.event_id = s.event_id " - "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? " - "ORDER BY stream_ordering ASC LIMIT %(limit)d " - ) % { - "limit": limit - } - - def f(txn): - # pull out all the events between the tokens - txn.execute(sql, (from_id.stream, to_id.stream,)) - rows = self.cursor_to_dict(txn) - - # Logic: - # - We want ALL events which match the AS room_id regex - # - We want ALL events which match the rooms represented by the AS - # room_alias regex - # - We want ALL events for rooms that AS users have joined. - # This is currently supported via get_app_service_rooms (which is - # used for the Notifier listener rooms). We can't reasonably make a - # SQL query for these room IDs, so we'll pull all the events between - # from/to and filter in python. - rooms_for_as = self._get_app_service_rooms_txn(txn, service) - room_ids_for_as = [r.room_id for r in rooms_for_as] - - def app_service_interested(row): - if row["room_id"] in room_ids_for_as: - return True +class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): + """This is an abstract base class where subclasses must implement + `get_room_max_stream_ordering` and `get_room_min_stream_ordering` + which can be called in the initializer. + """ - if row["type"] == EventTypes.Member: - if service.is_interested_in_user(row.get("state_key")): - return True - return False + __metaclass__ = abc.ABCMeta - return [r for r in rows if app_service_interested(r)] + def __init__(self, db_conn, hs): + super(StreamWorkerStore, self).__init__(db_conn, hs) - rows = yield self.runInteraction("get_appservice_room_stream", f) - - ret = yield self._get_events( - [r["event_id"] for r in rows], - get_prev_content=True + events_max = self.get_room_max_stream_ordering() + event_cache_prefill, min_event_val = self._get_cache_dict( + db_conn, "events", + entity_column="room_id", + stream_column="stream_ordering", + max_value=events_max, + ) + self._events_stream_cache = StreamChangeCache( + "EventsRoomStreamChangeCache", min_event_val, + prefilled_cache=event_cache_prefill, + ) + self._membership_stream_cache = StreamChangeCache( + "MembershipStreamChangeCache", events_max, ) - self._set_before_and_after(ret, rows, topo_order=from_id is None) + self._stream_order_on_start = self.get_room_max_stream_ordering() - if rows: - key = "s%d" % max(r["stream_ordering"] for r in rows) - else: - # Assume we didn't get anything because there was nothing to - # get. - key = to_key + @abc.abstractmethod + def get_room_max_stream_ordering(self): + raise NotImplementedError() - defer.returnValue((ret, key)) + @abc.abstractmethod + def get_room_min_stream_ordering(self): + raise NotImplementedError() @defer.inlineCallbacks def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0, @@ -233,13 +198,14 @@ class StreamStore(SQLBaseStore): results = {} room_ids = list(room_ids) - for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)): - res = yield preserve_context_over_deferred(defer.gatherResults([ - preserve_fn(self.get_room_events_stream_for_room)( + for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)): + res = yield make_deferred_yieldable(defer.gatherResults([ + run_in_background( + self.get_room_events_stream_for_room, room_id, from_key, to_key, limit, order=order, ) for room_id in rm_ids - ])) + ], consumeErrors=True)) results.update(dict(zip(rm_ids, res))) defer.returnValue(results) @@ -381,88 +347,6 @@ class StreamStore(SQLBaseStore): defer.returnValue(ret) @defer.inlineCallbacks - def paginate_room_events(self, room_id, from_key, to_key=None, - direction='b', limit=-1, event_filter=None): - # Tokens really represent positions between elements, but we use - # the convention of pointing to the event before the gap. Hence - # we have a bit of asymmetry when it comes to equalities. - args = [False, room_id] - if direction == 'b': - order = "DESC" - bounds = upper_bound( - RoomStreamToken.parse(from_key), self.database_engine - ) - if to_key: - bounds = "%s AND %s" % (bounds, lower_bound( - RoomStreamToken.parse(to_key), self.database_engine - )) - else: - order = "ASC" - bounds = lower_bound( - RoomStreamToken.parse(from_key), self.database_engine - ) - if to_key: - bounds = "%s AND %s" % (bounds, upper_bound( - RoomStreamToken.parse(to_key), self.database_engine - )) - - filter_clause, filter_args = filter_to_clause(event_filter) - - if filter_clause: - bounds += " AND " + filter_clause - args.extend(filter_args) - - if int(limit) > 0: - args.append(int(limit)) - limit_str = " LIMIT ?" - else: - limit_str = "" - - sql = ( - "SELECT * FROM events" - " WHERE outlier = ? AND room_id = ? AND %(bounds)s" - " ORDER BY topological_ordering %(order)s," - " stream_ordering %(order)s %(limit)s" - ) % { - "bounds": bounds, - "order": order, - "limit": limit_str - } - - def f(txn): - txn.execute(sql, args) - - rows = self.cursor_to_dict(txn) - - if rows: - topo = rows[-1]["topological_ordering"] - toke = rows[-1]["stream_ordering"] - if direction == 'b': - # Tokens are positions between events. - # This token points *after* the last event in the chunk. - # We need it to point to the event before it in the chunk - # when we are going backwards so we subtract one from the - # stream part. - toke -= 1 - next_token = str(RoomStreamToken(topo, toke)) - else: - # TODO (erikj): We should work out what to do here instead. - next_token = to_key if to_key else from_key - - return rows, next_token, - - rows, token = yield self.runInteraction("paginate_room_events", f) - - events = yield self._get_events( - [r["event_id"] for r in rows], - get_prev_content=True - ) - - self._set_before_and_after(events, rows) - - defer.returnValue((events, token)) - - @defer.inlineCallbacks def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None): rows, token = yield self.get_recent_event_ids_for_room( room_id, limit, end_token, from_token @@ -534,6 +418,33 @@ class StreamStore(SQLBaseStore): "get_recent_events_for_room", get_recent_events_for_room_txn ) + def get_room_event_after_stream_ordering(self, room_id, stream_ordering): + """Gets details of the first event in a room at or after a stream ordering + + Args: + room_id (str): + stream_ordering (int): + + Returns: + Deferred[(int, int, str)]: + (stream ordering, topological ordering, event_id) + """ + def _f(txn): + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering >= ?" + " AND NOT outlier" + " ORDER BY stream_ordering" + " LIMIT 1" + ) + txn.execute(sql, (room_id, stream_ordering, )) + return txn.fetchone() + + return self.runInteraction( + "get_room_event_after_stream_ordering", _f, + ) + @defer.inlineCallbacks def get_room_events_max_id(self, room_id=None): """Returns the current token for rooms stream. @@ -542,7 +453,7 @@ class StreamStore(SQLBaseStore): `room_id` causes it to return the current room specific topological token. """ - token = yield self._stream_id_gen.get_current_token() + token = yield self.get_room_max_stream_ordering() if room_id is None: defer.returnValue("s%d" % (token,)) else: @@ -552,12 +463,6 @@ class StreamStore(SQLBaseStore): ) defer.returnValue("t%d-%d" % (topo, token)) - def get_room_max_stream_ordering(self): - return self._stream_id_gen.get_current_token() - - def get_room_min_stream_ordering(self): - return self._backfill_id_gen.get_current_token() - def get_stream_token_for_event(self, event_id): """The stream token for an event Args: @@ -829,3 +734,96 @@ class StreamStore(SQLBaseStore): updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) + + def has_room_changed_since(self, room_id, stream_id): + return self._events_stream_cache.has_entity_changed(room_id, stream_id) + + +class StreamStore(StreamWorkerStore): + def get_room_max_stream_ordering(self): + return self._stream_id_gen.get_current_token() + + def get_room_min_stream_ordering(self): + return self._backfill_id_gen.get_current_token() + + @defer.inlineCallbacks + def paginate_room_events(self, room_id, from_key, to_key=None, + direction='b', limit=-1, event_filter=None): + # Tokens really represent positions between elements, but we use + # the convention of pointing to the event before the gap. Hence + # we have a bit of asymmetry when it comes to equalities. + args = [False, room_id] + if direction == 'b': + order = "DESC" + bounds = upper_bound( + RoomStreamToken.parse(from_key), self.database_engine + ) + if to_key: + bounds = "%s AND %s" % (bounds, lower_bound( + RoomStreamToken.parse(to_key), self.database_engine + )) + else: + order = "ASC" + bounds = lower_bound( + RoomStreamToken.parse(from_key), self.database_engine + ) + if to_key: + bounds = "%s AND %s" % (bounds, upper_bound( + RoomStreamToken.parse(to_key), self.database_engine + )) + + filter_clause, filter_args = filter_to_clause(event_filter) + + if filter_clause: + bounds += " AND " + filter_clause + args.extend(filter_args) + + if int(limit) > 0: + args.append(int(limit)) + limit_str = " LIMIT ?" + else: + limit_str = "" + + sql = ( + "SELECT * FROM events" + " WHERE outlier = ? AND room_id = ? AND %(bounds)s" + " ORDER BY topological_ordering %(order)s," + " stream_ordering %(order)s %(limit)s" + ) % { + "bounds": bounds, + "order": order, + "limit": limit_str + } + + def f(txn): + txn.execute(sql, args) + + rows = self.cursor_to_dict(txn) + + if rows: + topo = rows[-1]["topological_ordering"] + toke = rows[-1]["stream_ordering"] + if direction == 'b': + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + toke -= 1 + next_token = str(RoomStreamToken(topo, toke)) + else: + # TODO (erikj): We should work out what to do here instead. + next_token = to_key if to_key else from_key + + return rows, next_token, + + rows, token = yield self.runInteraction("paginate_room_events", f) + + events = yield self._get_events( + [r["event_id"] for r in rows], + get_prev_content=True + ) + + self._set_before_and_after(events, rows) + + defer.returnValue((events, token)) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 5a2c1aa59b..6671d3cfca 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2014-2016 OpenMarket Ltd +# Copyright 2018 New Vector Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,25 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import SQLBaseStore +from synapse.storage.account_data import AccountDataWorkerStore + from synapse.util.caches.descriptors import cached from twisted.internet import defer -import ujson as json +import simplejson as json import logging -logger = logging.getLogger(__name__) +from six.moves import range +logger = logging.getLogger(__name__) -class TagsStore(SQLBaseStore): - def get_max_account_data_stream_id(self): - """Get the current max stream id for the private user data stream - - Returns: - A deferred int. - """ - return self._account_data_id_gen.get_current_token() +class TagsWorkerStore(AccountDataWorkerStore): @cached() def get_tags_for_user(self, user_id): """Get all the tags for a user. @@ -95,7 +91,7 @@ class TagsStore(SQLBaseStore): for stream_id, user_id, room_id in tag_ids: txn.execute(sql, (user_id, room_id)) tags = [] - for tag, content in txn.fetchall(): + for tag, content in txn: tags.append(json.dumps(tag) + ":" + content) tag_json = "{" + ",".join(tags) + "}" results.append((stream_id, user_id, room_id, tag_json)) @@ -104,7 +100,7 @@ class TagsStore(SQLBaseStore): batch_size = 50 results = [] - for i in xrange(0, len(tag_ids), batch_size): + for i in range(0, len(tag_ids), batch_size): tags = yield self.runInteraction( "get_all_updated_tag_content", get_tag_content, @@ -132,7 +128,7 @@ class TagsStore(SQLBaseStore): " WHERE user_id = ? AND stream_id > ?" ) txn.execute(sql, (user_id, stream_id)) - room_ids = [row[0] for row in txn.fetchall()] + room_ids = [row[0] for row in txn] return room_ids changed = self._account_data_stream_cache.has_entity_changed( @@ -170,6 +166,8 @@ class TagsStore(SQLBaseStore): row["tag"]: json.loads(row["content"]) for row in rows }) + +class TagsStore(TagsWorkerStore): @defer.inlineCallbacks def add_tag_to_room(self, user_id, room_id, tag, content): """Add a tag to a room for a user. diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py index 809fdd311f..f825264ea9 100644 --- a/synapse/storage/transactions.py +++ b/synapse/storage/transactions.py @@ -23,7 +23,7 @@ from canonicaljson import encode_canonical_json from collections import namedtuple import logging -import ujson as json +import simplejson as json logger = logging.getLogger(__name__) @@ -46,8 +46,8 @@ class TransactionStore(SQLBaseStore): """A collection of queries for handling PDUs. """ - def __init__(self, hs): - super(TransactionStore, self).__init__(hs) + def __init__(self, db_conn, hs): + super(TransactionStore, self).__init__(db_conn, hs) self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py new file mode 100644 index 0000000000..d6e289ffbe --- /dev/null +++ b/synapse/storage/user_directory.py @@ -0,0 +1,764 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 Vector Creations Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import SQLBaseStore + +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.api.constants import EventTypes, JoinRules +from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import get_domain_from_id, get_localpart_from_id + +import re +import logging + +logger = logging.getLogger(__name__) + + +class UserDirectoryStore(SQLBaseStore): + @cachedInlineCallbacks(cache_context=True) + def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context): + """Check if the room is either world_readable or publically joinable + """ + current_state_ids = yield self.get_current_state_ids( + room_id, on_invalidate=cache_context.invalidate + ) + + join_rules_id = current_state_ids.get((EventTypes.JoinRules, "")) + if join_rules_id: + join_rule_ev = yield self.get_event(join_rules_id, allow_none=True) + if join_rule_ev: + if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC: + defer.returnValue(True) + + hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, "")) + if hist_vis_id: + hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True) + if hist_vis_ev: + if hist_vis_ev.content.get("history_visibility") == "world_readable": + defer.returnValue(True) + + defer.returnValue(False) + + @defer.inlineCallbacks + def add_users_to_public_room(self, room_id, user_ids): + """Add user to the list of users in public rooms + + Args: + room_id (str): A room_id that all users are in that is world_readable + or publically joinable + user_ids (list(str)): Users to add + """ + yield self._simple_insert_many( + table="users_in_public_rooms", + values=[ + { + "user_id": user_id, + "room_id": room_id, + } + for user_id in user_ids + ], + desc="add_users_to_public_room" + ) + for user_id in user_ids: + self.get_user_in_public_room.invalidate((user_id,)) + + def add_profiles_to_user_dir(self, room_id, users_with_profile): + """Add profiles to the user directory + + Args: + room_id (str): A room_id that all users are joined to + users_with_profile (dict): Users to add to directory in the form of + mapping of user_id -> ProfileInfo + """ + if isinstance(self.database_engine, PostgresEngine): + # We weight the loclpart most highly, then display name and finally + # server name + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + args = ( + ( + user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id), + profile.display_name, + ) + for user_id, profile in users_with_profile.iteritems() + ) + elif isinstance(self.database_engine, Sqlite3Engine): + sql = """ + INSERT INTO user_directory_search(user_id, value) + VALUES (?,?) + """ + args = ( + ( + user_id, + "%s %s" % (user_id, p.display_name,) if p.display_name else user_id + ) + for user_id, p in users_with_profile.iteritems() + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + def _add_profiles_to_user_dir_txn(txn): + txn.executemany(sql, args) + self._simple_insert_many_txn( + txn, + table="user_directory", + values=[ + { + "user_id": user_id, + "room_id": room_id, + "display_name": profile.display_name, + "avatar_url": profile.avatar_url, + } + for user_id, profile in users_with_profile.iteritems() + ] + ) + for user_id in users_with_profile: + txn.call_after( + self.get_user_in_directory.invalidate, (user_id,) + ) + + return self.runInteraction( + "add_profiles_to_user_dir", _add_profiles_to_user_dir_txn + ) + + @defer.inlineCallbacks + def update_user_in_user_dir(self, user_id, room_id): + yield self._simple_update_one( + table="user_directory", + keyvalues={"user_id": user_id}, + updatevalues={"room_id": room_id}, + desc="update_user_in_user_dir", + ) + self.get_user_in_directory.invalidate((user_id,)) + + def update_profile_in_user_dir(self, user_id, display_name, avatar_url, room_id): + def _update_profile_in_user_dir_txn(txn): + new_entry = self._simple_upsert_txn( + txn, + table="user_directory", + keyvalues={"user_id": user_id}, + insertion_values={"room_id": room_id}, + values={"display_name": display_name, "avatar_url": avatar_url}, + lock=False, # We're only inserter + ) + + if isinstance(self.database_engine, PostgresEngine): + # We weight the localpart most highly, then display name and finally + # server name + if new_entry: + sql = """ + INSERT INTO user_directory_search(user_id, vector) + VALUES (?, + setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + ) + """ + txn.execute( + sql, + ( + user_id, get_localpart_from_id(user_id), + get_domain_from_id(user_id), display_name, + ) + ) + else: + sql = """ + UPDATE user_directory_search + SET vector = setweight(to_tsvector('english', ?), 'A') + || setweight(to_tsvector('english', ?), 'D') + || setweight(to_tsvector('english', COALESCE(?, '')), 'B') + WHERE user_id = ? + """ + txn.execute( + sql, + ( + get_localpart_from_id(user_id), get_domain_from_id(user_id), + display_name, user_id, + ) + ) + elif isinstance(self.database_engine, Sqlite3Engine): + value = "%s %s" % (user_id, display_name,) if display_name else user_id + self._simple_upsert_txn( + txn, + table="user_directory_search", + keyvalues={"user_id": user_id}, + values={"value": value}, + lock=False, # We're only inserter + ) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) + + return self.runInteraction( + "update_profile_in_user_dir", _update_profile_in_user_dir_txn + ) + + @defer.inlineCallbacks + def update_user_in_public_user_list(self, user_id, room_id): + yield self._simple_update_one( + table="users_in_public_rooms", + keyvalues={"user_id": user_id}, + updatevalues={"room_id": room_id}, + desc="update_user_in_public_user_list", + ) + self.get_user_in_public_room.invalidate((user_id,)) + + def remove_from_user_dir(self, user_id): + def _remove_from_user_dir_txn(txn): + self._simple_delete_txn( + txn, + table="user_directory", + keyvalues={"user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="user_directory_search", + keyvalues={"user_id": user_id}, + ) + self._simple_delete_txn( + txn, + table="users_in_public_rooms", + keyvalues={"user_id": user_id}, + ) + txn.call_after( + self.get_user_in_directory.invalidate, (user_id,) + ) + txn.call_after( + self.get_user_in_public_room.invalidate, (user_id,) + ) + return self.runInteraction( + "remove_from_user_dir", _remove_from_user_dir_txn, + ) + + @defer.inlineCallbacks + def remove_from_user_in_public_room(self, user_id): + yield self._simple_delete( + table="users_in_public_rooms", + keyvalues={"user_id": user_id}, + desc="remove_from_user_in_public_room", + ) + self.get_user_in_public_room.invalidate((user_id,)) + + def get_users_in_public_due_to_room(self, room_id): + """Get all user_ids that are in the room directory becuase they're + in the given room_id + """ + return self._simple_select_onecol( + table="users_in_public_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_public_due_to_room", + ) + + @defer.inlineCallbacks + def get_users_in_dir_due_to_room(self, room_id): + """Get all user_ids that are in the room directory becuase they're + in the given room_id + """ + user_ids_dir = yield self._simple_select_onecol( + table="user_directory", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids_pub = yield self._simple_select_onecol( + table="users_in_public_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids_share = yield self._simple_select_onecol( + table="users_who_share_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids = set(user_ids_dir) + user_ids.update(user_ids_pub) + user_ids.update(user_ids_share) + + defer.returnValue(user_ids) + + @defer.inlineCallbacks + def get_all_rooms(self): + """Get all room_ids we've ever known about, in ascending order of "size" + """ + sql = """ + SELECT room_id FROM current_state_events + GROUP BY room_id + ORDER BY count(*) ASC + """ + rows = yield self._execute("get_all_rooms", None, sql) + defer.returnValue([room_id for room_id, in rows]) + + @defer.inlineCallbacks + def get_all_local_users(self): + """Get all local users + """ + sql = """ + SELECT name FROM users + """ + rows = yield self._execute("get_all_local_users", None, sql) + defer.returnValue([name for name, in rows]) + + def add_users_who_share_room(self, room_id, share_private, user_id_tuples): + """Insert entries into the users_who_share_rooms table. The first + user should be a local user. + + Args: + room_id (str) + share_private (bool): Is the room private + user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + """ + def _add_users_who_share_room_txn(txn): + self._simple_insert_many_txn( + txn, + table="users_who_share_rooms", + values=[ + { + "user_id": user_id, + "other_user_id": other_user_id, + "room_id": room_id, + "share_private": share_private, + } + for user_id, other_user_id in user_id_tuples + ], + ) + for user_id, other_user_id in user_id_tuples: + txn.call_after( + self.get_users_who_share_room_from_dir.invalidate, + (user_id,), + ) + txn.call_after( + self.get_if_users_share_a_room.invalidate, + (user_id, other_user_id), + ) + return self.runInteraction( + "add_users_who_share_room", _add_users_who_share_room_txn + ) + + def update_users_who_share_room(self, room_id, share_private, user_id_sets): + """Updates entries in the users_who_share_rooms table. The first + user should be a local user. + + Args: + room_id (str) + share_private (bool): Is the room private + user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + """ + def _update_users_who_share_room_txn(txn): + sql = """ + UPDATE users_who_share_rooms + SET room_id = ?, share_private = ? + WHERE user_id = ? AND other_user_id = ? + """ + txn.executemany( + sql, + ( + (room_id, share_private, uid, oid) + for uid, oid in user_id_sets + ) + ) + for user_id, other_user_id in user_id_sets: + txn.call_after( + self.get_users_who_share_room_from_dir.invalidate, + (user_id,), + ) + txn.call_after( + self.get_if_users_share_a_room.invalidate, + (user_id, other_user_id), + ) + return self.runInteraction( + "update_users_who_share_room", _update_users_who_share_room_txn + ) + + def remove_user_who_share_room(self, user_id, other_user_id): + """Deletes entries in the users_who_share_rooms table. The first + user should be a local user. + + Args: + room_id (str) + share_private (bool): Is the room private + user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + """ + def _remove_user_who_share_room_txn(txn): + self._simple_delete_txn( + txn, + table="users_who_share_rooms", + keyvalues={ + "user_id": user_id, + "other_user_id": other_user_id, + }, + ) + txn.call_after( + self.get_users_who_share_room_from_dir.invalidate, + (user_id,), + ) + txn.call_after( + self.get_if_users_share_a_room.invalidate, + (user_id, other_user_id), + ) + + return self.runInteraction( + "remove_user_who_share_room", _remove_user_who_share_room_txn + ) + + @cached(max_entries=500000) + def get_if_users_share_a_room(self, user_id, other_user_id): + """Gets if users share a room. + + Args: + user_id (str): Must be a local user_id + other_user_id (str) + + Returns: + bool|None: None if they don't share a room, otherwise whether they + share a private room or not. + """ + return self._simple_select_one_onecol( + table="users_who_share_rooms", + keyvalues={ + "user_id": user_id, + "other_user_id": other_user_id, + }, + retcol="share_private", + allow_none=True, + desc="get_if_users_share_a_room", + ) + + @cachedInlineCallbacks(max_entries=500000, iterable=True) + def get_users_who_share_room_from_dir(self, user_id): + """Returns the set of users who share a room with `user_id` + + Args: + user_id(str): Must be a local user + + Returns: + dict: user_id -> share_private mapping + """ + rows = yield self._simple_select_list( + table="users_who_share_rooms", + keyvalues={ + "user_id": user_id, + }, + retcols=("other_user_id", "share_private",), + desc="get_users_who_share_room_with_user", + ) + + defer.returnValue({ + row["other_user_id"]: row["share_private"] + for row in rows + }) + + def get_users_in_share_dir_with_room_id(self, user_id, room_id): + """Get all user tuples that are in the users_who_share_rooms due to the + given room_id. + + Returns: + [(user_id, other_user_id)]: where one of the two will match the given + user_id. + """ + sql = """ + SELECT user_id, other_user_id FROM users_who_share_rooms + WHERE room_id = ? AND (user_id = ? OR other_user_id = ?) + """ + return self._execute( + "get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id + ) + + @defer.inlineCallbacks + def get_rooms_in_common_for_users(self, user_id, other_user_id): + """Given two user_ids find out the list of rooms they share. + """ + sql = """ + SELECT room_id FROM ( + SELECT c.room_id FROM current_state_events AS c + INNER JOIN room_memberships USING (event_id) + WHERE type = 'm.room.member' + AND membership = 'join' + AND state_key = ? + ) AS f1 INNER JOIN ( + SELECT c.room_id FROM current_state_events AS c + INNER JOIN room_memberships USING (event_id) + WHERE type = 'm.room.member' + AND membership = 'join' + AND state_key = ? + ) f2 USING (room_id) + """ + + rows = yield self._execute( + "get_rooms_in_common_for_users", None, sql, user_id, other_user_id + ) + + defer.returnValue([room_id for room_id, in rows]) + + def delete_all_from_user_dir(self): + """Delete the entire user directory + """ + def _delete_all_from_user_dir_txn(txn): + txn.execute("DELETE FROM user_directory") + txn.execute("DELETE FROM user_directory_search") + txn.execute("DELETE FROM users_in_public_rooms") + txn.execute("DELETE FROM users_who_share_rooms") + txn.call_after(self.get_user_in_directory.invalidate_all) + txn.call_after(self.get_user_in_public_room.invalidate_all) + txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all) + txn.call_after(self.get_if_users_share_a_room.invalidate_all) + return self.runInteraction( + "delete_all_from_user_dir", _delete_all_from_user_dir_txn + ) + + @cached() + def get_user_in_directory(self, user_id): + return self._simple_select_one( + table="user_directory", + keyvalues={"user_id": user_id}, + retcols=("room_id", "display_name", "avatar_url",), + allow_none=True, + desc="get_user_in_directory", + ) + + @cached() + def get_user_in_public_room(self, user_id): + return self._simple_select_one( + table="users_in_public_rooms", + keyvalues={"user_id": user_id}, + retcols=("room_id",), + allow_none=True, + desc="get_user_in_public_room", + ) + + def get_user_directory_stream_pos(self): + return self._simple_select_one_onecol( + table="user_directory_stream_pos", + keyvalues={}, + retcol="stream_id", + desc="get_user_directory_stream_pos", + ) + + def update_user_directory_stream_pos(self, stream_id): + return self._simple_update_one( + table="user_directory_stream_pos", + keyvalues={}, + updatevalues={"stream_id": stream_id}, + desc="update_user_directory_stream_pos", + ) + + def get_current_state_deltas(self, prev_stream_id): + prev_stream_id = int(prev_stream_id) + if not self._curr_state_delta_stream_cache.has_any_entity_changed(prev_stream_id): + return [] + + def get_current_state_deltas_txn(txn): + # First we calculate the max stream id that will give us less than + # N results. + # We arbitarily limit to 100 stream_id entries to ensure we don't + # select toooo many. + sql = """ + SELECT stream_id, count(*) + FROM current_state_delta_stream + WHERE stream_id > ? + GROUP BY stream_id + ORDER BY stream_id ASC + LIMIT 100 + """ + txn.execute(sql, (prev_stream_id,)) + + total = 0 + max_stream_id = prev_stream_id + for max_stream_id, count in txn: + total += count + if total > 100: + # We arbitarily limit to 100 entries to ensure we don't + # select toooo many. + break + + # Now actually get the deltas + sql = """ + SELECT stream_id, room_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + txn.execute(sql, (prev_stream_id, max_stream_id,)) + return self.cursor_to_dict(txn) + + return self.runInteraction( + "get_current_state_deltas", get_current_state_deltas_txn + ) + + def get_max_stream_id_in_current_state_deltas(self): + return self._simple_select_one_onecol( + table="current_state_delta_stream", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), -1)", + desc="get_max_stream_id_in_current_state_deltas", + ) + + @defer.inlineCallbacks + def search_user_dir(self, user_id, search_term, limit): + """Searches for users in directory + + Returns: + dict of the form:: + + { + "limited": <bool>, # whether there were more results or not + "results": [ # Ordered by best match first + { + "user_id": <user_id>, + "display_name": <display_name>, + "avatar_url": <avatar_url> + } + ] + } + """ + + if self.hs.config.user_directory_search_all_users: + # make s.user_id null to keep the ordering algorithm happy + join_clause = """ + CROSS JOIN (SELECT NULL as user_id) AS s + """ + join_args = () + where_clause = "1=1" + else: + join_clause = """ + LEFT JOIN users_in_public_rooms AS p USING (user_id) + LEFT JOIN ( + SELECT other_user_id AS user_id FROM users_who_share_rooms + WHERE user_id = ? AND share_private + ) AS s USING (user_id) + """ + join_args = (user_id,) + where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)" + + if isinstance(self.database_engine, PostgresEngine): + full_query, exact_query, prefix_query = _parse_query_postgres(search_term) + + # We order by rank and then if they have profile info + # The ranking algorithm is hand tweaked for "best" results. Broadly + # the idea is we give a higher weight to exact matches. + # The array of numbers are the weights for the various part of the + # search: (domain, _, display name, localpart) + sql = """ + SELECT d.user_id AS user_id, display_name, avatar_url + FROM user_directory_search + INNER JOIN user_directory AS d USING (user_id) + %s + WHERE + %s + AND vector @@ to_tsquery('english', ?) + ORDER BY + (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) + * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END) + * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END) + * ( + 3 * ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) + + ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) + ) + DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ % (join_clause, where_clause) + args = join_args + (full_query, exact_query, prefix_query, limit + 1,) + elif isinstance(self.database_engine, Sqlite3Engine): + search_query = _parse_query_sqlite(search_term) + + sql = """ + SELECT d.user_id AS user_id, display_name, avatar_url + FROM user_directory_search + INNER JOIN user_directory AS d USING (user_id) + %s + WHERE + %s + AND value MATCH ? + ORDER BY + rank(matchinfo(user_directory_search)) DESC, + display_name IS NULL, + avatar_url IS NULL + LIMIT ? + """ % (join_clause, where_clause) + args = join_args + (search_query, limit + 1) + else: + # This should be unreachable. + raise Exception("Unrecognized database engine") + + results = yield self._execute( + "search_user_dir", self.cursor_to_dict, sql, *args + ) + + limited = len(results) > limit + + defer.returnValue({ + "limited": limited, + "results": results, + }) + + +def _parse_query_sqlite(search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + + We specifically add both a prefix and non prefix matching term so that + exact matches get ranked higher. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + return " & ".join("(%s* OR %s)" % (result, result,) for result in results) + + +def _parse_query_postgres(search_term): + """Takes a plain unicode string from the user and converts it into a form + that can be passed to database. + We use this so that we can add prefix matching, which isn't something + that is supported by default. + """ + + # Pull out the individual words, discarding any non-word characters. + results = re.findall(r"([\w\-]+)", search_term, re.UNICODE) + + both = " & ".join("(%s:* | %s)" % (result, result,) for result in results) + exact = " & ".join("%s" % (result,) for result in results) + prefix = " & ".join("%s:*" % (result,) for result in results) + + return both, exact, prefix diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 46cf93ff87..95031dc9ec 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -30,6 +30,17 @@ class IdGenerator(object): def _load_current_id(db_conn, table, column, step=1): + """ + + Args: + db_conn (object): + table (str): + column (str): + step (int): + + Returns: + int + """ cur = db_conn.cursor() if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) @@ -131,6 +142,9 @@ class StreamIdGenerator(object): def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. + + Returns: + int """ with self._lock: if self._unfinished_ids: |