diff options
Diffstat (limited to 'synapse/storage')
28 files changed, 2120 insertions, 521 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e8495f1eb9..d604e7668f 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore, self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) + self._device_list_id_gen = StreamIdGenerator( + db_conn, "device_lists_stream", "stream_id", + ) self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") @@ -210,6 +213,14 @@ class DataStore(RoomMemberStore, RoomStore, prefilled_cache=device_outbox_prefill, ) + device_list_max = self._device_list_id_gen.get_current_token() + self._device_list_stream_cache = StreamChangeCache( + "DeviceListStreamChangeCache", device_list_max, + ) + self._device_list_federation_stream_cache = StreamChangeCache( + "DeviceListFederationStreamChangeCache", device_list_max, + ) + cur = LoggingTransaction( db_conn.cursor(), name="_find_stream_orderings_for_times_txn", @@ -286,6 +297,82 @@ class DataStore(RoomMemberStore, RoomStore, desc="get_user_ip_and_agents", ) + def get_users(self): + """Function to reterive a list of users in users table. + + Args: + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self._simple_select_list( + table="users", + keyvalues={}, + retcols=[ + "name", + "password_hash", + "is_guest", + "admin" + ], + desc="get_users", + ) + + def get_users_paginate(self, order, start, limit): + """Function to reterive a paginated list of users from + users list. This will return a json object, which contains + list of users and the total number of users in users table. + + Args: + order (str): column name to order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + Returns: + defer.Deferred: resolves to json object {list[dict[str, Any]], count} + """ + is_guest = 0 + i_start = (int)(start) + i_limit = (int)(limit) + return self.get_user_list_paginate( + table="users", + keyvalues={ + "is_guest": is_guest + }, + pagevalues=[ + order, + i_limit, + i_start + ], + retcols=[ + "name", + "password_hash", + "is_guest", + "admin" + ], + desc="get_users_paginate", + ) + + def search_users(self, term): + """Function to search users list for one or more users with + the matched term. + + Args: + term (str): search term + col (str): column to query term should be matched to + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self._simple_search_list( + table="users", + term=term, + col="name", + retcols=[ + "name", + "password_hash", + "is_guest", + "admin" + ], + desc="search_users", + ) + def are_all_users_on_domain(txn, database_engine, domain): sql = database_engine.convert_param_style( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 963ef999d5..c659004e8d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -18,7 +18,6 @@ from synapse.api.errors import StoreError from synapse.util.logcontext import LoggingContext, PreserveLoggingContext from synapse.util.caches.dictionary_cache import DictionaryCache from synapse.util.caches.descriptors import Cache -from synapse.util.caches import intern_dict from synapse.storage.engines import PostgresEngine import synapse.metrics @@ -74,13 +73,22 @@ class LoggingTransaction(object): def __setattr__(self, name, value): setattr(self.txn, name, value) + def __iter__(self): + return self.txn.__iter__() + def execute(self, sql, *args): self._do_execute(self.txn.execute, sql, *args) def executemany(self, sql, *args): self._do_execute(self.txn.executemany, sql, *args) + def _make_sql_one_line(self, sql): + "Strip newlines out of SQL so that the loggers in the DB are on one line" + return " ".join(l.strip() for l in sql.splitlines() if l.strip()) + def _do_execute(self, func, sql, *args): + sql = self._make_sql_one_line(sql) + # TODO(paul): Maybe use 'info' and 'debug' for values? sql_logger.debug("[SQL] {%s} %s", self.name, sql) @@ -127,7 +135,7 @@ class PerformanceCounters(object): def interval(self, interval_duration, limit=3): counters = [] - for name, (count, cum_time) in self.current_counters.items(): + for name, (count, cum_time) in self.current_counters.iteritems(): prev_count, prev_time = self.previous_counters.get(name, (0, 0)) counters.append(( (cum_time - prev_time) / interval_duration, @@ -350,9 +358,9 @@ class SQLBaseStore(object): Returns: A list of dicts where the key is the column header. """ - col_headers = list(column[0] for column in cursor.description) + col_headers = list(intern(column[0]) for column in cursor.description) results = list( - intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall() + dict(zip(col_headers, row)) for row in cursor ) return results @@ -387,6 +395,10 @@ class SQLBaseStore(object): Args: table : string giving the table name values : dict of new column names and values for them + + Returns: + bool: Whether the row was inserted or not. Only useful when + `or_ignore` is True """ try: yield self.runInteraction( @@ -398,6 +410,8 @@ class SQLBaseStore(object): # a cursor after we receive an error from the db. if not or_ignore: raise + defer.returnValue(False) + defer.returnValue(True) @staticmethod def _simple_insert_txn(txn, table, values): @@ -477,10 +491,6 @@ class SQLBaseStore(object): " AND ".join("%s = ?" % (k,) for k in keyvalues) ) sqlargs = values.values() + keyvalues.values() - logger.debug( - "[SQL] %s Args=%s", - sql, sqlargs, - ) txn.execute(sql, sqlargs) if txn.rowcount == 0: @@ -495,10 +505,6 @@ class SQLBaseStore(object): ", ".join(k for k in allvalues), ", ".join("?" for _ in allvalues) ) - logger.debug( - "[SQL] %s Args=%s", - sql, keyvalues.values(), - ) txn.execute(sql, allvalues.values()) return True @@ -562,7 +568,7 @@ class SQLBaseStore(object): @staticmethod def _simple_select_onecol_txn(txn, table, keyvalues, retcol): if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) else: where = "" @@ -576,7 +582,7 @@ class SQLBaseStore(object): txn.execute(sql, keyvalues.values()) - return [r[0] for r in txn.fetchall()] + return [r[0] for r in txn] def _simple_select_onecol(self, table, keyvalues, retcol, desc="_simple_select_onecol"): @@ -709,7 +715,7 @@ class SQLBaseStore(object): ) values.extend(iterable) - for key, value in keyvalues.items(): + for key, value in keyvalues.iteritems(): clauses.append("%s = ?" % (key,)) values.append(value) @@ -750,7 +756,7 @@ class SQLBaseStore(object): @staticmethod def _simple_update_one_txn(txn, table, keyvalues, updatevalues): if keyvalues: - where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) + where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys()) else: where = "" @@ -837,6 +843,47 @@ class SQLBaseStore(object): return txn.execute(sql, keyvalues.values()) + def _simple_delete_many(self, table, column, iterable, keyvalues, desc): + return self.runInteraction( + desc, self._simple_delete_many_txn, table, column, iterable, keyvalues + ) + + @staticmethod + def _simple_delete_many_txn(txn, table, column, iterable, keyvalues): + """Executes a DELETE query on the named table. + + Filters rows by if value of `column` is in `iterable`. + + Args: + txn : Transaction object + table : string giving the table name + column : column name to test for inclusion against `iterable` + iterable : list + keyvalues : dict of column names and values to select the rows with + """ + if not iterable: + return + + sql = "DELETE FROM %s" % table + + clauses = [] + values = [] + clauses.append( + "%s IN (%s)" % (column, ",".join("?" for _ in iterable)) + ) + values.extend(iterable) + + for key, value in keyvalues.iteritems(): + clauses.append("%s = ?" % (key,)) + values.append(value) + + if clauses: + sql = "%s WHERE %s" % ( + sql, + " AND ".join(clauses), + ) + return txn.execute(sql, values) + def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value, limit=100000): # Fetch a mapping of room_id -> max stream position for "recent" rooms. @@ -857,16 +904,16 @@ class SQLBaseStore(object): txn = db_conn.cursor() txn.execute(sql, (int(max_value),)) - rows = txn.fetchall() - txn.close() cache = { row[0]: int(row[1]) - for row in rows + for row in txn } + txn.close() + if cache: - min_val = min(cache.values()) + min_val = min(cache.itervalues()) else: min_val = max_value @@ -928,6 +975,165 @@ class SQLBaseStore(object): else: return 0 + def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols, + desc="_simple_select_list_paginate"): + """Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + retcols (iterable[str]): the names of the columns to return + order (str): order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + """ + return self.runInteraction( + desc, + self._simple_select_list_paginate_txn, + table, keyvalues, pagevalues, retcols + ) + + @classmethod + def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols): + """Executes a SELECT query on the named table with start and limit, + of row numbers, which may return zero or number of rows from start to limit, + returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + keyvalues (dict[str, T] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + pagevalues ([]): + order (str): order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] + + """ + if keyvalues: + sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k,) for k in keyvalues), + " ? ASC LIMIT ? OFFSET ?" + ) + txn.execute(sql, keyvalues.values() + pagevalues) + else: + sql = "SELECT %s FROM %s ORDER BY %s" % ( + ", ".join(retcols), + table, + " ? ASC LIMIT ? OFFSET ?" + ) + txn.execute(sql, pagevalues) + + return cls.cursor_to_dict(txn) + + @defer.inlineCallbacks + def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols, + desc="get_user_list_paginate"): + """Get a list of users from start row to a limit number of rows. This will + return a json object with users and total number of users in users list. + + Args: + table (str): the table name + keyvalues (dict[str, Any] | None): + column names and values to select the rows with, or None to not + apply a WHERE clause. + pagevalues ([]): + order (str): order the select by this column + start (int): start number to begin the query from + limit (int): number of rows to reterive + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to json object {list[dict[str, Any]], count} + """ + users = yield self.runInteraction( + desc, + self._simple_select_list_paginate_txn, + table, keyvalues, pagevalues, retcols + ) + count = yield self.runInteraction( + desc, + self.get_user_count_txn + ) + retval = { + "users": users, + "total": count + } + defer.returnValue(retval) + + def get_user_count_txn(self, txn): + """Get a total number of registerd users in the users list. + + Args: + txn : Transaction object + Returns: + defer.Deferred: resolves to int + """ + sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;" + txn.execute(sql_count) + count = txn.fetchone()[0] + defer.returnValue(count) + + def _simple_search_list(self, table, term, col, retcols, + desc="_simple_search_list"): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + + return self.runInteraction( + desc, + self._simple_search_list_txn, + table, term, col, retcols + ) + + @classmethod + def _simple_search_list_txn(cls, txn, table, term, col, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + txn : Transaction object + table (str): the table name + term (str | None): + term for searching the table matched to a column. + col (str): column to query term should be matched to + retcols (iterable[str]): the names of the columns to return + Returns: + defer.Deferred: resolves to list[dict[str, Any]] or None + """ + if term: + sql = "SELECT %s FROM %s WHERE %s LIKE ?" % ( + ", ".join(retcols), + table, + col + ) + termvalues = ["%%" + term + "%%"] + txn.execute(sql, termvalues) + else: + return 0 + + return cls.cursor_to_dict(txn) + class _RollbackButIsFineException(Exception): """ This exception is used to rollback a transaction without implying diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py index 3fa226e92d..aa84ffc2b0 100644 --- a/synapse/storage/account_data.py +++ b/synapse/storage/account_data.py @@ -182,7 +182,7 @@ class AccountDataStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) global_account_data = { - row[0]: json.loads(row[1]) for row in txn.fetchall() + row[0]: json.loads(row[1]) for row in txn } sql = ( @@ -193,7 +193,7 @@ class AccountDataStore(SQLBaseStore): txn.execute(sql, (user_id, stream_id)) account_data_by_room = {} - for row in txn.fetchall(): + for row in txn: room_account_data = account_data_by_room.setdefault(row[0], {}) room_account_data[row[1]] = json.loads(row[2]) diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 94b2bcc54a..813ad59e56 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import synapse.util.async from ._base import SQLBaseStore from . import engines @@ -84,24 +85,14 @@ class BackgroundUpdateStore(SQLBaseStore): self._background_update_performance = {} self._background_update_queue = [] self._background_update_handlers = {} - self._background_update_timer = None @defer.inlineCallbacks def start_doing_background_updates(self): - assert self._background_update_timer is None, \ - "background updates already running" - logger.info("Starting background schema updates") while True: - sleep = defer.Deferred() - self._background_update_timer = self._clock.call_later( - self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None - ) - try: - yield sleep - finally: - self._background_update_timer = None + yield synapse.util.async.sleep( + self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.) try: result = yield self.do_next_background_update( diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index bde3b5cbbc..2714519d21 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -20,6 +20,8 @@ from twisted.internet import defer from .background_updates import BackgroundUpdateStore +from synapse.util.caches.expiringcache import ExpiringCache + logger = logging.getLogger(__name__) @@ -42,6 +44,15 @@ class DeviceInboxStore(BackgroundUpdateStore): self._background_drop_index_device_inbox, ) + # Map of (user_id, device_id) to the last stream_id that has been + # deleted up to. This is so that we can no op deletions. + self._last_device_delete_cache = ExpiringCache( + cache_name="last_device_delete_cache", + clock=self._clock, + max_len=10000, + expiry_ms=30 * 60 * 1000, + ) + @defer.inlineCallbacks def add_messages_to_device_inbox(self, local_messages_by_user_then_device, remote_messages_by_destination): @@ -167,7 +178,7 @@ class DeviceInboxStore(BackgroundUpdateStore): ) txn.execute(sql, (user_id,)) message_json = ujson.dumps(messages_by_device["*"]) - for row in txn.fetchall(): + for row in txn: # Add the message for all devices for this user on this # server. device = row[0] @@ -184,7 +195,7 @@ class DeviceInboxStore(BackgroundUpdateStore): # TODO: Maybe this needs to be done in batches if there are # too many local devices for a given user. txn.execute(sql, [user_id] + devices) - for row in txn.fetchall(): + for row in txn: # Only insert into the local inbox if the device exists on # this server device = row[0] @@ -240,7 +251,7 @@ class DeviceInboxStore(BackgroundUpdateStore): user_id, device_id, last_stream_id, current_stream_id, limit )) messages = [] - for row in txn.fetchall(): + for row in txn: stream_pos = row[0] messages.append(ujson.loads(row[1])) if len(messages) < limit: @@ -251,6 +262,7 @@ class DeviceInboxStore(BackgroundUpdateStore): "get_new_messages_for_device", get_new_messages_for_device_txn, ) + @defer.inlineCallbacks def delete_messages_for_device(self, user_id, device_id, up_to_stream_id): """ Args: @@ -260,6 +272,18 @@ class DeviceInboxStore(BackgroundUpdateStore): Returns: A deferred that resolves to the number of messages deleted. """ + # If we have cached the last stream id we've deleted up to, we can + # check if there is likely to be anything that needs deleting + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), None + ) + if last_deleted_stream_id: + has_changed = self._device_inbox_stream_cache.has_entity_changed( + user_id, last_deleted_stream_id + ) + if not has_changed: + defer.returnValue(0) + def delete_messages_for_device_txn(txn): sql = ( "DELETE FROM device_inbox" @@ -269,10 +293,20 @@ class DeviceInboxStore(BackgroundUpdateStore): txn.execute(sql, (user_id, device_id, up_to_stream_id)) return txn.rowcount - return self.runInteraction( + count = yield self.runInteraction( "delete_messages_for_device", delete_messages_for_device_txn ) + # Update the cache, ensuring that we only ever increase the value + last_deleted_stream_id = self._last_device_delete_cache.get( + (user_id, device_id), 0 + ) + self._last_device_delete_cache[(user_id, device_id)] = max( + last_deleted_stream_id, up_to_stream_id + ) + + defer.returnValue(count) + def get_all_new_device_messages(self, last_pos, current_pos, limit): """ Args: @@ -306,7 +340,7 @@ class DeviceInboxStore(BackgroundUpdateStore): " ORDER BY stream_id ASC" ) txn.execute(sql, (last_pos, upper_pos)) - rows.extend(txn.fetchall()) + rows.extend(txn) return rows @@ -323,12 +357,12 @@ class DeviceInboxStore(BackgroundUpdateStore): """ Args: destination(str): The name of the remote server. - last_stream_id(int): The last position of the device message stream + last_stream_id(int|long): The last position of the device message stream that the server sent up to. - current_stream_id(int): The current position of the device + current_stream_id(int|long): The current position of the device message stream. Returns: - Deferred ([dict], int): List of messages for the device and where + Deferred ([dict], int|long): List of messages for the device and where in the stream the messages got to. """ @@ -350,7 +384,7 @@ class DeviceInboxStore(BackgroundUpdateStore): destination, last_stream_id, current_stream_id, limit )) messages = [] - for row in txn.fetchall(): + for row in txn: stream_pos = row[0] messages.append(ujson.loads(row[1])) if len(messages) < limit: diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index 17920d4480..53e36791d5 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -13,37 +13,49 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import ujson as json from twisted.internet import defer from synapse.api.errors import StoreError from ._base import SQLBaseStore +from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks + logger = logging.getLogger(__name__) class DeviceStore(SQLBaseStore): + def __init__(self, hs): + super(DeviceStore, self).__init__(hs) + + self._clock.looping_call( + self._prune_old_outbound_device_pokes, 60 * 60 * 1000 + ) + + self.register_background_index_update( + "device_lists_stream_idx", + index_name="device_lists_stream_user_id", + table="device_lists_stream", + columns=["user_id", "device_id"], + ) + @defer.inlineCallbacks def store_device(self, user_id, device_id, - initial_device_display_name, - ignore_if_known=True): + initial_device_display_name): """Ensure the given device is known; add it to the store if not Args: user_id (str): id of user associated with the device device_id (str): id of device initial_device_display_name (str): initial displayname of the - device - ignore_if_known (bool): ignore integrity errors which mean the - device is already known + device. Ignored if device exists. Returns: - defer.Deferred - Raises: - StoreError: if ignore_if_known is False and the device was already - known + defer.Deferred: boolean whether the device was inserted or an + existing device existed with that ID. """ try: - yield self._simple_insert( + inserted = yield self._simple_insert( "devices", values={ "user_id": user_id, @@ -51,8 +63,9 @@ class DeviceStore(SQLBaseStore): "display_name": initial_device_display_name }, desc="store_device", - or_ignore=ignore_if_known, + or_ignore=True, ) + defer.returnValue(inserted) except Exception as e: logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" " display_name=%s(%r) failed: %s", @@ -95,6 +108,23 @@ class DeviceStore(SQLBaseStore): desc="delete_device", ) + def delete_devices(self, user_id, device_ids): + """Deletes several devices. + + Args: + user_id (str): The ID of the user which owns the devices + device_ids (list): The IDs of the devices to delete + Returns: + defer.Deferred + """ + return self._simple_delete_many( + table="devices", + column="device_id", + iterable=device_ids, + keyvalues={"user_id": user_id}, + desc="delete_devices", + ) + def update_device(self, user_id, device_id, new_display_name=None): """Update a device. @@ -139,3 +169,490 @@ class DeviceStore(SQLBaseStore): ) defer.returnValue({d["device_id"]: d for d in devices}) + + @cached(max_entries=10000) + def get_device_list_last_stream_id_for_remote(self, user_id): + """Get the last stream_id we got for a user. May be None if we haven't + got any information for them. + """ + return self._simple_select_one_onecol( + table="device_lists_remote_extremeties", + keyvalues={"user_id": user_id}, + retcol="stream_id", + desc="get_device_list_remote_extremity", + allow_none=True, + ) + + @cachedList(cached_method_name="get_device_list_last_stream_id_for_remote", + list_name="user_ids", inlineCallbacks=True) + def get_device_list_last_stream_id_for_remotes(self, user_ids): + rows = yield self._simple_select_many_batch( + table="device_lists_remote_extremeties", + column="user_id", + iterable=user_ids, + retcols=("user_id", "stream_id",), + desc="get_user_devices_from_cache", + ) + + results = {user_id: None for user_id in user_ids} + results.update({ + row["user_id"]: row["stream_id"] for row in rows + }) + + defer.returnValue(results) + + @defer.inlineCallbacks + def mark_remote_user_device_list_as_unsubscribed(self, user_id): + """Mark that we no longer track device lists for remote user. + """ + yield self._simple_delete( + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + desc="mark_remote_user_device_list_as_unsubscribed", + ) + self.get_device_list_last_stream_id_for_remote.invalidate((user_id,)) + + def update_remote_device_list_cache_entry(self, user_id, device_id, content, + stream_id): + """Updates a single user's device in the cache. + """ + return self.runInteraction( + "update_remote_device_list_cache_entry", + self._update_remote_device_list_cache_entry_txn, + user_id, device_id, content, stream_id, + ) + + def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id, + content, stream_id): + self._simple_upsert_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "content": json.dumps(content), + } + ) + + txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,)) + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + + def update_remote_device_list_cache(self, user_id, devices, stream_id): + """Replace the cache of the remote user's devices. + """ + return self.runInteraction( + "update_remote_device_list_cache", + self._update_remote_device_list_cache_txn, + user_id, devices, stream_id, + ) + + def _update_remote_device_list_cache_txn(self, txn, user_id, devices, + stream_id): + self._simple_delete_txn( + txn, + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_remote_cache", + values=[ + { + "user_id": user_id, + "device_id": content["device_id"], + "content": json.dumps(content), + } + for content in devices + ] + ) + + txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,)) + txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) + txn.call_after( + self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) + ) + + self._simple_upsert_txn( + txn, + table="device_lists_remote_extremeties", + keyvalues={ + "user_id": user_id, + }, + values={ + "stream_id": stream_id, + } + ) + + def get_devices_by_remote(self, destination, from_stream_id): + """Get stream of updates to send to remote servers + + Returns: + (int, list[dict]): current stream id and list of updates + """ + now_stream_id = self._device_list_id_gen.get_current_token() + + has_changed = self._device_list_federation_stream_cache.has_entity_changed( + destination, int(from_stream_id) + ) + if not has_changed: + return (now_stream_id, []) + + return self.runInteraction( + "get_devices_by_remote", self._get_devices_by_remote_txn, + destination, from_stream_id, now_stream_id, + ) + + def _get_devices_by_remote_txn(self, txn, destination, from_stream_id, + now_stream_id): + sql = """ + SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes + WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ? + GROUP BY user_id, device_id + LIMIT 20 + """ + txn.execute( + sql, (destination, from_stream_id, now_stream_id, False) + ) + + # maps (user_id, device_id) -> stream_id + query_map = {(r[0], r[1]): r[2] for r in txn} + if not query_map: + return (now_stream_id, []) + + if len(query_map) >= 20: + now_stream_id = max(stream_id for stream_id in query_map.itervalues()) + + devices = self._get_e2e_device_keys_txn( + txn, query_map.keys(), include_all_devices=True + ) + + prev_sent_id_sql = """ + SELECT coalesce(max(stream_id), 0) as stream_id + FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND stream_id <= ? + """ + + results = [] + for user_id, user_devices in devices.iteritems(): + # The prev_id for the first row is always the last row before + # `from_stream_id` + txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id)) + rows = txn.fetchall() + prev_id = rows[0][0] + for device_id, device in user_devices.iteritems(): + stream_id = query_map[(user_id, device_id)] + result = { + "user_id": user_id, + "device_id": device_id, + "prev_id": [prev_id] if prev_id else [], + "stream_id": stream_id, + } + + prev_id = stream_id + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return (now_stream_id, results) + + @defer.inlineCallbacks + def get_user_devices_from_cache(self, query_list): + """Get the devices (and keys if any) for remote users from the cache. + + Args: + query_list(list): List of (user_id, device_ids), if device_ids is + falsey then return all device ids for that user. + + Returns: + (user_ids_not_in_cache, results_map), where user_ids_not_in_cache is + a set of user_ids and results_map is a mapping of + user_id -> device_id -> device_info + """ + user_ids = set(user_id for user_id, _ in query_list) + user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids)) + user_ids_in_cache = set( + user_id for user_id, stream_id in user_map.items() if stream_id + ) + user_ids_not_in_cache = user_ids - user_ids_in_cache + + results = {} + for user_id, device_id in query_list: + if user_id not in user_ids_in_cache: + continue + + if device_id: + device = yield self._get_cached_user_device(user_id, device_id) + results.setdefault(user_id, {})[device_id] = device + else: + results[user_id] = yield self._get_cached_devices_for_user(user_id) + + defer.returnValue((user_ids_not_in_cache, results)) + + @cachedInlineCallbacks(num_args=2, tree=True) + def _get_cached_user_device(self, user_id, device_id): + content = yield self._simple_select_one_onecol( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="content", + desc="_get_cached_user_device", + ) + defer.returnValue(json.loads(content)) + + @cachedInlineCallbacks() + def _get_cached_devices_for_user(self, user_id): + devices = yield self._simple_select_list( + table="device_lists_remote_cache", + keyvalues={ + "user_id": user_id, + }, + retcols=("device_id", "content"), + desc="_get_cached_devices_for_user", + ) + defer.returnValue({ + device["device_id"]: json.loads(device["content"]) + for device in devices + }) + + def get_devices_with_keys_by_user(self, user_id): + """Get all devices (with any device keys) for a user + + Returns: + (stream_id, devices) + """ + return self.runInteraction( + "get_devices_with_keys_by_user", + self._get_devices_with_keys_by_user_txn, user_id, + ) + + def _get_devices_with_keys_by_user_txn(self, txn, user_id): + now_stream_id = self._device_list_id_gen.get_current_token() + + devices = self._get_e2e_device_keys_txn( + txn, [(user_id, None)], include_all_devices=True + ) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.iteritems(): + result = { + "device_id": device_id, + } + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = json.loads(key_json) + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + + def mark_as_sent_devices_by_remote(self, destination, stream_id): + """Mark that updates have successfully been sent to the destination. + """ + return self.runInteraction( + "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, + destination, stream_id, + ) + + def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): + # First we DELETE all rows such that only the latest row for each + # (destination, user_id is left. We do this by selecting first and + # deleting. + sql = """ + SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes + WHERE destination = ? AND stream_id <= ? + GROUP BY user_id + HAVING count(*) > 1 + """ + txn.execute(sql, (destination, stream_id,)) + rows = txn.fetchall() + + sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE destination = ? AND user_id = ? AND stream_id < ? + """ + txn.executemany( + sql, ((destination, row[0], row[1],) for row in rows) + ) + + # Mark everything that is left as sent + sql = """ + UPDATE device_lists_outbound_pokes SET sent = ? + WHERE destination = ? AND stream_id <= ? + """ + txn.execute(sql, (True, destination, stream_id,)) + + @defer.inlineCallbacks + def get_user_whose_devices_changed(self, from_key): + """Get set of users whose devices have changed since `from_key`. + """ + from_key = int(from_key) + changed = self._device_list_stream_cache.get_all_entities_changed(from_key) + if changed is not None: + defer.returnValue(set(changed)) + + sql = """ + SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ? + """ + rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + defer.returnValue(set(row[0] for row in rows)) + + def get_all_device_list_changes_for_remotes(self, from_key): + """Return a list of `(stream_id, user_id, destination)` which is the + combined list of changes to devices, and which destinations need to be + poked. `destination` may be None if no destinations need to be poked. + """ + sql = """ + SELECT stream_id, user_id, destination FROM device_lists_stream + LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id) + WHERE stream_id > ? + """ + return self._execute( + "get_all_device_list_changes_for_remotes", None, + sql, from_key, + ) + + @defer.inlineCallbacks + def add_device_change_to_streams(self, user_id, device_ids, hosts): + """Persist that a user's devices have been updated, and which hosts + (if any) should be poked. + """ + with self._device_list_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_device_change_to_streams", self._add_device_change_txn, + user_id, device_ids, hosts, stream_id, + ) + defer.returnValue(stream_id) + + def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id): + now = self._clock.time_msec() + + txn.call_after( + self._device_list_stream_cache.entity_has_changed, + user_id, stream_id, + ) + for host in hosts: + txn.call_after( + self._device_list_federation_stream_cache.entity_has_changed, + host, stream_id, + ) + + # Delete older entries in the table, as we really only care about + # when the latest change happened. + txn.executemany( + """ + DELETE FROM device_lists_stream + WHERE user_id = ? AND device_id = ? AND stream_id < ? + """, + [(user_id, device_id, stream_id) for device_id in device_ids] + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_stream", + values=[ + { + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + } + for device_id in device_ids + ] + ) + + self._simple_insert_many_txn( + txn, + table="device_lists_outbound_pokes", + values=[ + { + "destination": destination, + "stream_id": stream_id, + "user_id": user_id, + "device_id": device_id, + "sent": False, + "ts": now, + } + for destination in hosts + for device_id in device_ids + ] + ) + + def get_device_stream_token(self): + return self._device_list_id_gen.get_current_token() + + def _prune_old_outbound_device_pokes(self): + """Delete old entries out of the device_lists_outbound_pokes to ensure + that we don't fill up due to dead servers. We keep one entry per + (destination, user_id) tuple to ensure that the prev_ids remain correct + if the server does come back. + """ + yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000 + + def _prune_txn(txn): + select_sql = """ + SELECT destination, user_id, max(stream_id) as stream_id + FROM device_lists_outbound_pokes + GROUP BY destination, user_id + HAVING min(ts) < ? AND count(*) > 1 + """ + + txn.execute(select_sql, (yesterday,)) + rows = txn.fetchall() + + if not rows: + return + + delete_sql = """ + DELETE FROM device_lists_outbound_pokes + WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ? + """ + + txn.executemany( + delete_sql, + ( + (yesterday, row[0], row[1], row[2]) + for row in rows + ) + ) + + logger.info("Pruned %d device list outbound pokes", txn.rowcount) + + return self.runInteraction( + "_prune_old_outbound_device_pokes", _prune_txn + ) diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py index 385d607056..7cbc1470fd 100644 --- a/synapse/storage/end_to_end_keys.py +++ b/synapse/storage/end_to_end_keys.py @@ -12,95 +12,173 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import collections +from twisted.internet import defer -import twisted.internet.defer +from synapse.api.errors import SynapseError + +from canonicaljson import encode_canonical_json +import ujson as json from ._base import SQLBaseStore class EndToEndKeyStore(SQLBaseStore): - def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes): - return self._simple_upsert( - table="e2e_device_keys_json", - keyvalues={ - "user_id": user_id, - "device_id": device_id, - }, - values={ - "ts_added_ms": time_now, - "key_json": json_bytes, - } + def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. + """ + def _set_e2e_device_keys_txn(txn): + old_key_json = self._simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + retcol="key_json", + allow_none=True, + ) + + new_key_json = encode_canonical_json(device_keys) + if old_key_json == new_key_json: + return False + + self._simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + values={ + "ts_added_ms": time_now, + "key_json": new_key_json, + } + ) + + return True + + return self.runInteraction( + "set_e2e_device_keys", _set_e2e_device_keys_txn ) - def get_e2e_device_keys(self, query_list): + @defer.inlineCallbacks + def get_e2e_device_keys(self, query_list, include_all_devices=False): """Fetch a list of device keys. Args: query_list(list): List of pairs of user_ids and device_ids. + include_all_devices (bool): whether to include entries for devices + that don't have device keys Returns: Dict mapping from user-id to dict mapping from device_id to dict containing "key_json", "device_display_name". """ if not query_list: - return {} + defer.returnValue({}) - return self.runInteraction( - "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list + results = yield self.runInteraction( + "get_e2e_device_keys", self._get_e2e_device_keys_txn, + query_list, include_all_devices, ) - def _get_e2e_device_keys_txn(self, txn, query_list): + for user_id, device_keys in results.iteritems(): + for device_id, device_info in device_keys.iteritems(): + device_info["keys"] = json.loads(device_info.pop("key_json")) + + defer.returnValue(results) + + def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices): query_clauses = [] query_params = [] for (user_id, device_id) in query_list: - query_clause = "k.user_id = ?" + query_clause = "user_id = ?" query_params.append(user_id) - if device_id: - query_clause += " AND k.device_id = ?" + if device_id is not None: + query_clause += " AND device_id = ?" query_params.append(device_id) query_clauses.append(query_clause) sql = ( - "SELECT k.user_id, k.device_id, " + "SELECT user_id, device_id, " " d.display_name AS device_display_name, " " k.key_json" - " FROM e2e_device_keys_json k" - " LEFT JOIN devices d ON d.user_id = k.user_id" - " AND d.device_id = k.device_id" + " FROM devices d" + " %s JOIN e2e_device_keys_json k USING (user_id, device_id)" " WHERE %s" ) % ( + "LEFT" if include_all_devices else "INNER", " OR ".join("(" + q + ")" for q in query_clauses) ) txn.execute(sql, query_params) rows = self.cursor_to_dict(txn) - result = collections.defaultdict(dict) + result = {} for row in rows: - result[row["user_id"]][row["device_id"]] = row + result.setdefault(row["user_id"], {})[row["device_id"]] = row return result + @defer.inlineCallbacks def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): + """Insert some new one time keys for a device. + + Checks if any of the keys are already inserted, if they are then check + if they match. If they don't then we raise an error. + """ + + # First we check if we have already persisted any of the keys. + rows = yield self._simple_select_many_batch( + table="e2e_one_time_keys_json", + column="key_id", + iterable=[key_id for _, key_id, _ in key_list], + retcols=("algorithm", "key_id", "key_json",), + keyvalues={ + "user_id": user_id, + "device_id": device_id, + }, + desc="add_e2e_one_time_keys_check", + ) + + existing_key_map = { + (row["algorithm"], row["key_id"]): row["key_json"] for row in rows + } + + new_keys = [] # Keys that we need to insert + for algorithm, key_id, json_bytes in key_list: + ex_bytes = existing_key_map.get((algorithm, key_id), None) + if ex_bytes: + if json_bytes != ex_bytes: + raise SynapseError( + 400, "One time key with key_id %r already exists" % (key_id,) + ) + else: + new_keys.append((algorithm, key_id, json_bytes)) + def _add_e2e_one_time_keys(txn): - for (algorithm, key_id, json_bytes) in key_list: - self._simple_upsert_txn( - txn, table="e2e_one_time_keys_json", - keyvalues={ + # We are protected from race between lookup and insertion due to + # a unique constraint. If there is a race of two calls to + # `add_e2e_one_time_keys` then they'll conflict and we will only + # insert one set. + self._simple_insert_many_txn( + txn, table="e2e_one_time_keys_json", + values=[ + { "user_id": user_id, "device_id": device_id, "algorithm": algorithm, "key_id": key_id, - }, - values={ "ts_added_ms": time_now, "key_json": json_bytes, } - ) - return self.runInteraction( - "add_e2e_one_time_keys", _add_e2e_one_time_keys + for algorithm, key_id, json_bytes in new_keys + ], + ) + yield self.runInteraction( + "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys ) def count_e2e_one_time_keys(self, user_id, device_id): @@ -116,7 +194,7 @@ class EndToEndKeyStore(SQLBaseStore): ) txn.execute(sql, (user_id, device_id)) result = {} - for algorithm, key_count in txn.fetchall(): + for algorithm, key_count in txn: result[algorithm] = key_count return result return self.runInteraction( @@ -137,7 +215,7 @@ class EndToEndKeyStore(SQLBaseStore): user_result = result.setdefault(user_id, {}) device_result = user_result.setdefault(device_id, {}) txn.execute(sql, (user_id, device_id, algorithm)) - for key_id, key_json in txn.fetchall(): + for key_id, key_json in txn: device_result[algorithm + ":" + key_id] = key_json delete.append((user_id, device_id, algorithm, key_id)) sql = ( @@ -152,7 +230,7 @@ class EndToEndKeyStore(SQLBaseStore): "claim_e2e_one_time_keys", _claim_e2e_one_time_keys ) - @twisted.internet.defer.inlineCallbacks + @defer.inlineCallbacks def delete_e2e_keys_by_device(self, user_id, device_id): yield self._simple_delete( table="e2e_device_keys_json", diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index f0aa2193fb..43b5b49986 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -74,7 +74,7 @@ class EventFederationStore(SQLBaseStore): base_sql % (",".join(["?"] * len(chunk)),), chunk ) - new_front.update([r[0] for r in txn.fetchall()]) + new_front.update([r[0] for r in txn]) new_front -= results @@ -110,7 +110,7 @@ class EventFederationStore(SQLBaseStore): txn.execute(sql, (room_id, False,)) - return dict(txn.fetchall()) + return dict(txn) def _get_oldest_events_in_room_txn(self, txn, room_id): return self._simple_select_onecol_txn( @@ -129,7 +129,7 @@ class EventFederationStore(SQLBaseStore): room_id, ) - @cached() + @cached(max_entries=5000, iterable=True) def get_latest_event_ids_in_room(self, room_id): return self._simple_select_onecol( table="event_forward_extremities", @@ -152,7 +152,7 @@ class EventFederationStore(SQLBaseStore): txn.execute(sql, (room_id, )) results = [] - for event_id, depth in txn.fetchall(): + for event_id, depth in txn: hashes = self._get_event_reference_hashes_txn(txn, event_id) prev_hashes = { k: encode_base64(v) for k, v in hashes.items() @@ -201,19 +201,19 @@ class EventFederationStore(SQLBaseStore): def _update_min_depth_for_room_txn(self, txn, room_id, depth): min_depth = self._get_min_depth_interaction(txn, room_id) - do_insert = depth < min_depth if min_depth else True + if min_depth and depth >= min_depth: + return - if do_insert: - self._simple_upsert_txn( - txn, - table="room_depth", - keyvalues={ - "room_id": room_id, - }, - values={ - "min_depth": depth, - }, - ) + self._simple_upsert_txn( + txn, + table="room_depth", + keyvalues={ + "room_id": room_id, + }, + values={ + "min_depth": depth, + }, + ) def _handle_mult_prev_events(self, txn, events): """ @@ -281,15 +281,30 @@ class EventFederationStore(SQLBaseStore): ) def get_forward_extremeties_for_room(self, room_id, stream_ordering): + """For a given room_id and stream_ordering, return the forward + extremeties of the room at that point in "time". + + Throws a StoreError if we have since purged the index for + stream_orderings from that point. + + Args: + room_id (str): + stream_ordering (int): + + Returns: + deferred, which resolves to a list of event_ids + """ # We want to make the cache more effective, so we clamp to the last # change before the given ordering. last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # We don't always have a full stream_to_exterm_id table, e.g. after # the upgrade that introduced it, so we make sure we never ask for a - # try and pin to a stream_ordering from before a restart + # stream_ordering from before a restart last_change = max(self._stream_order_on_start, last_change) + # provided the last_change is recent enough, we now clamp the requested + # stream_ordering to it. if last_change > self.stream_ordering_month_ago: stream_ordering = min(last_change, stream_ordering) @@ -319,8 +334,7 @@ class EventFederationStore(SQLBaseStore): def get_forward_extremeties_for_room_txn(txn): txn.execute(sql, (stream_ordering, room_id)) - rows = txn.fetchall() - return [event_id for event_id, in rows] + return [event_id for event_id, in txn] return self.runInteraction( "get_forward_extremeties_for_room", @@ -421,7 +435,7 @@ class EventFederationStore(SQLBaseStore): (room_id, event_id, False, limit - len(event_results)) ) - for row in txn.fetchall(): + for row in txn: if row[1] not in event_results: queue.put((-row[0], row[1])) @@ -467,7 +481,7 @@ class EventFederationStore(SQLBaseStore): (room_id, event_id, False, limit - len(event_results)) ) - for e_id, in txn.fetchall(): + for e_id, in txn: new_front.add(e_id) new_front -= earliest_events diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py index 7de3e8c58c..d6d8723b4a 100644 --- a/synapse/storage/event_push_actions.py +++ b/synapse/storage/event_push_actions.py @@ -15,6 +15,7 @@ from ._base import SQLBaseStore from twisted.internet import defer +from synapse.util.async import sleep from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.types import RoomStreamToken from .stream import lower_bound @@ -25,11 +26,46 @@ import ujson as json logger = logging.getLogger(__name__) +DEFAULT_NOTIF_ACTION = ["notify", {"set_tweak": "highlight", "value": False}] +DEFAULT_HIGHLIGHT_ACTION = [ + "notify", {"set_tweak": "sound", "value": "default"}, {"set_tweak": "highlight"} +] + + +def _serialize_action(actions, is_highlight): + """Custom serializer for actions. This allows us to "compress" common actions. + + We use the fact that most users have the same actions for notifs (and for + highlights). + We store these default actions as the empty string rather than the full JSON. + Since the empty string isn't valid JSON there is no risk of this clashing with + any real JSON actions + """ + if is_highlight: + if actions == DEFAULT_HIGHLIGHT_ACTION: + return "" # We use empty string as the column is non-NULL + else: + if actions == DEFAULT_NOTIF_ACTION: + return "" + return json.dumps(actions) + + +def _deserialize_action(actions, is_highlight): + """Custom deserializer for actions. This allows us to "compress" common actions + """ + if actions: + return json.loads(actions) + + if is_highlight: + return DEFAULT_HIGHLIGHT_ACTION + else: + return DEFAULT_NOTIF_ACTION + + class EventPushActionsStore(SQLBaseStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" def __init__(self, hs): - self.stream_ordering_month_ago = None super(EventPushActionsStore, self).__init__(hs) self.register_background_index_update( @@ -47,6 +83,11 @@ class EventPushActionsStore(SQLBaseStore): where_clause="highlight=1" ) + self._doing_notif_rotation = False + self._rotate_notif_loop = self._clock.looping_call( + self._rotate_notifs, 30 * 60 * 1000 + ) + def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples): """ Args: @@ -55,15 +96,17 @@ class EventPushActionsStore(SQLBaseStore): """ values = [] for uid, actions in tuples: + is_highlight = 1 if _action_has_highlight(actions) else 0 + values.append({ 'room_id': event.room_id, 'event_id': event.event_id, 'user_id': uid, - 'actions': json.dumps(actions), + 'actions': _serialize_action(actions, is_highlight), 'stream_ordering': event.internal_metadata.stream_ordering, 'topological_ordering': event.depth, 'notif': 1, - 'highlight': 1 if _action_has_highlight(actions) else 0, + 'highlight': is_highlight, }) for uid, __ in tuples: @@ -77,66 +120,83 @@ class EventPushActionsStore(SQLBaseStore): def get_unread_event_push_actions_by_room_for_user( self, room_id, user_id, last_read_event_id ): - def _get_unread_event_push_actions_by_room(txn): - sql = ( - "SELECT stream_ordering, topological_ordering" - " FROM events" - " WHERE room_id = ? AND event_id = ?" - ) - txn.execute( - sql, (room_id, last_read_event_id) - ) - results = txn.fetchall() - if len(results) == 0: - return {"notify_count": 0, "highlight_count": 0} - - stream_ordering = results[0][0] - topological_ordering = results[0][1] - token = RoomStreamToken( - topological_ordering, stream_ordering - ) - - # First get number of notifications. - # We don't need to put a notif=1 clause as all rows always have - # notif=1 - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " user_id = ?" - " AND room_id = ?" - " AND %s" - ) % (lower_bound(token, self.database_engine, inclusive=False),) - - txn.execute(sql, (user_id, room_id)) - row = txn.fetchone() - notify_count = row[0] if row else 0 + ret = yield self.runInteraction( + "get_unread_event_push_actions_by_room", + self._get_unread_counts_by_receipt_txn, + room_id, user_id, last_read_event_id + ) + defer.returnValue(ret) - # Now get the number of highlights - sql = ( - "SELECT count(*)" - " FROM event_push_actions ea" - " WHERE" - " highlight = 1" - " AND user_id = ?" - " AND room_id = ?" - " AND %s" - ) % (lower_bound(token, self.database_engine, inclusive=False),) + def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id, + last_read_event_id): + sql = ( + "SELECT stream_ordering, topological_ordering" + " FROM events" + " WHERE room_id = ? AND event_id = ?" + ) + txn.execute( + sql, (room_id, last_read_event_id) + ) + results = txn.fetchall() + if len(results) == 0: + return {"notify_count": 0, "highlight_count": 0} - txn.execute(sql, (user_id, room_id)) - row = txn.fetchone() - highlight_count = row[0] if row else 0 + stream_ordering = results[0][0] + topological_ordering = results[0][1] - return { - "notify_count": notify_count, - "highlight_count": highlight_count, - } + return self._get_unread_counts_by_pos_txn( + txn, room_id, user_id, topological_ordering, stream_ordering + ) - ret = yield self.runInteraction( - "get_unread_event_push_actions_by_room", - _get_unread_event_push_actions_by_room + def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering, + stream_ordering): + token = RoomStreamToken( + topological_ordering, stream_ordering ) - defer.returnValue(ret) + + # First get number of notifications. + # We don't need to put a notif=1 clause as all rows always have + # notif=1 + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " user_id = ?" + " AND room_id = ?" + " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + notify_count = row[0] if row else 0 + + txn.execute(""" + SELECT notif_count FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering > ? + """, (room_id, user_id, stream_ordering,)) + rows = txn.fetchall() + if rows: + notify_count += rows[0][0] + + # Now get the number of highlights + sql = ( + "SELECT count(*)" + " FROM event_push_actions ea" + " WHERE" + " highlight = 1" + " AND user_id = ?" + " AND room_id = ?" + " AND %s" + ) % (lower_bound(token, self.database_engine, inclusive=False),) + + txn.execute(sql, (user_id, room_id)) + row = txn.fetchone() + highlight_count = row[0] if row else 0 + + return { + "notify_count": notify_count, + "highlight_count": highlight_count, + } @defer.inlineCallbacks def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering): @@ -146,7 +206,7 @@ class EventPushActionsStore(SQLBaseStore): " stream_ordering >= ? AND stream_ordering <= ?" ) txn.execute(sql, (min_stream_ordering, max_stream_ordering)) - return [r[0] for r in txn.fetchall()] + return [r[0] for r in txn] ret = yield self.runInteraction("get_push_action_users_in_range", f) defer.returnValue(ret) @@ -176,7 +236,8 @@ class EventPushActionsStore(SQLBaseStore): # find rooms that have a read receipt in them and return the next # push actions sql = ( - "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions" + "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," + " ep.highlight " " FROM (" " SELECT room_id," " MAX(topological_ordering) as topological_ordering," @@ -217,7 +278,7 @@ class EventPushActionsStore(SQLBaseStore): def get_no_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight " " FROM event_push_actions AS ep" " INNER JOIN events AS e USING (room_id, event_id)" " WHERE" @@ -246,7 +307,7 @@ class EventPushActionsStore(SQLBaseStore): "event_id": row[0], "room_id": row[1], "stream_ordering": row[2], - "actions": json.loads(row[3]), + "actions": _deserialize_action(row[3], row[4]), } for row in after_read_receipt + no_read_receipt ] @@ -285,7 +346,7 @@ class EventPushActionsStore(SQLBaseStore): def get_after_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight, e.received_ts" " FROM (" " SELECT room_id," " MAX(topological_ordering) as topological_ordering," @@ -327,7 +388,7 @@ class EventPushActionsStore(SQLBaseStore): def get_no_receipt(txn): sql = ( "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," - " e.received_ts" + " ep.highlight, e.received_ts" " FROM event_push_actions AS ep" " INNER JOIN events AS e USING (room_id, event_id)" " WHERE" @@ -357,8 +418,8 @@ class EventPushActionsStore(SQLBaseStore): "event_id": row[0], "room_id": row[1], "stream_ordering": row[2], - "actions": json.loads(row[3]), - "received_ts": row[4], + "actions": _deserialize_action(row[3], row[4]), + "received_ts": row[5], } for row in after_read_receipt + no_read_receipt ] @@ -392,7 +453,7 @@ class EventPushActionsStore(SQLBaseStore): sql = ( "SELECT epa.event_id, epa.room_id," " epa.stream_ordering, epa.topological_ordering," - " epa.actions, epa.profile_tag, e.received_ts" + " epa.actions, epa.highlight, epa.profile_tag, e.received_ts" " FROM event_push_actions epa, events e" " WHERE epa.event_id = e.event_id" " AND epa.user_id = ? %s" @@ -407,7 +468,7 @@ class EventPushActionsStore(SQLBaseStore): "get_push_actions_for_user", f ) for pa in push_actions: - pa["actions"] = json.loads(pa["actions"]) + pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) defer.returnValue(push_actions) @defer.inlineCallbacks @@ -448,10 +509,14 @@ class EventPushActionsStore(SQLBaseStore): ) def _remove_old_push_actions_before_txn(self, txn, room_id, user_id, - topological_ordering): + topological_ordering, stream_ordering): """ - Purges old, stale push actions for a user and room before a given - topological_ordering + Purges old push actions for a user and room before a given + topological_ordering. + + We however keep a months worth of highlighted notifications, so that + users can still get a list of recent highlights. + Args: txn: The transcation room_id: Room ID to delete from @@ -475,10 +540,16 @@ class EventPushActionsStore(SQLBaseStore): txn.execute( "DELETE FROM event_push_actions " " WHERE user_id = ? AND room_id = ? AND " - " topological_ordering < ? AND stream_ordering < ?", + " topological_ordering <= ?" + " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", (user_id, room_id, topological_ordering, self.stream_ordering_month_ago) ) + txn.execute(""" + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? + """, (room_id, user_id, stream_ordering)) + @defer.inlineCallbacks def _find_stream_orderings_for_times(self): yield self.runInteraction( @@ -495,6 +566,14 @@ class EventPushActionsStore(SQLBaseStore): "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago ) + logger.info("Searching for stream ordering 1 day ago") + self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( + txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 + ) + logger.info( + "Found stream ordering 1 day ago: it's %d", + self.stream_ordering_day_ago + ) def _find_first_stream_ordering_after_ts_txn(self, txn, ts): """ @@ -534,6 +613,131 @@ class EventPushActionsStore(SQLBaseStore): return range_end + @defer.inlineCallbacks + def _rotate_notifs(self): + if self._doing_notif_rotation or self.stream_ordering_day_ago is None: + return + self._doing_notif_rotation = True + + try: + while True: + logger.info("Rotating notifications") + + caught_up = yield self.runInteraction( + "_rotate_notifs", + self._rotate_notifs_txn + ) + if caught_up: + break + yield sleep(5) + finally: + self._doing_notif_rotation = False + + def _rotate_notifs_txn(self, txn): + """Archives older notifications into event_push_summary. Returns whether + the archiving process has caught up or not. + """ + + old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + # We don't to try and rotate millions of rows at once, so we cap the + # maximum stream ordering we'll rotate before. + txn.execute(""" + SELECT stream_ordering FROM event_push_actions + WHERE stream_ordering > ? + ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000 + """, (old_rotate_stream_ordering,)) + stream_row = txn.fetchone() + if stream_row: + offset_stream_ordering, = stream_row + rotate_to_stream_ordering = min( + self.stream_ordering_day_ago, offset_stream_ordering + ) + caught_up = offset_stream_ordering >= self.stream_ordering_day_ago + else: + rotate_to_stream_ordering = self.stream_ordering_day_ago + caught_up = True + + logger.info("Rotating notifications up to: %s", rotate_to_stream_ordering) + + self._rotate_notifs_before_txn(txn, rotate_to_stream_ordering) + + # We have caught up iff we were limited by `stream_ordering_day_ago` + return caught_up + + def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering): + old_rotate_stream_ordering = self._simple_select_one_onecol_txn( + txn, + table="event_push_summary_stream_ordering", + keyvalues={}, + retcol="stream_ordering", + ) + + # Calculate the new counts that should be upserted into event_push_summary + sql = """ + SELECT user_id, room_id, + coalesce(old.notif_count, 0) + upd.notif_count, + upd.stream_ordering, + old.user_id + FROM ( + SELECT user_id, room_id, count(*) as notif_count, + max(stream_ordering) as stream_ordering + FROM event_push_actions + WHERE ? <= stream_ordering AND stream_ordering < ? + AND highlight = 0 + GROUP BY user_id, room_id + ) AS upd + LEFT JOIN event_push_summary AS old USING (user_id, room_id) + """ + + txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering,)) + rows = txn.fetchall() + + logger.info("Rotating notifications, handling %d rows", len(rows)) + + # If the `old.user_id` above is NULL then we know there isn't already an + # entry in the table, so we simply insert it. Otherwise we update the + # existing table. + self._simple_insert_many_txn( + txn, + table="event_push_summary", + values=[ + { + "user_id": row[0], + "room_id": row[1], + "notif_count": row[2], + "stream_ordering": row[3], + } + for row in rows if row[4] is None + ] + ) + + txn.executemany( + """ + UPDATE event_push_summary SET notif_count = ?, stream_ordering = ? + WHERE user_id = ? AND room_id = ? + """, + ((row[2], row[3], row[0], row[1],) for row in rows if row[4] is not None) + ) + + txn.execute( + "DELETE FROM event_push_actions" + " WHERE ? <= stream_ordering AND stream_ordering < ? AND highlight = 0", + (old_rotate_stream_ordering, rotate_to_stream_ordering,) + ) + + logger.info("Rotating notifications, deleted %s push actions", txn.rowcount) + + txn.execute( + "UPDATE event_push_summary_stream_ordering SET stream_ordering = ?", + (rotate_to_stream_ordering,) + ) + def _action_has_highlight(actions): for action in actions: diff --git a/synapse/storage/events.py b/synapse/storage/events.py index 6160949f32..3f6833fad2 100644 --- a/synapse/storage/events.py +++ b/synapse/storage/events.py @@ -28,19 +28,22 @@ from synapse.util.metrics import Measure from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.state import resolve_events +from synapse.util.caches.descriptors import cached from canonicaljson import encode_canonical_json from collections import deque, namedtuple, OrderedDict from functools import wraps -import synapse import synapse.metrics - import logging import math import ujson as json +# these are only included to make the type annotations work +from synapse.events import EventBase # noqa: F401 +from synapse.events.snapshot import EventContext # noqa: F401 + logger = logging.getLogger(__name__) @@ -81,6 +84,11 @@ class _EventPeristenceQueue(object): def add_to_queue(self, room_id, events_and_contexts, backfilled): """Add events to the queue, with the given persist_event options. + + Args: + room_id (str): + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): """ queue = self._event_persist_queues.setdefault(room_id, deque()) if queue: @@ -209,14 +217,14 @@ class EventsStore(SQLBaseStore): partitioned.setdefault(event.room_id, []).append((event, ctx)) deferreds = [] - for room_id, evs_ctxs in partitioned.items(): + for room_id, evs_ctxs in partitioned.iteritems(): d = preserve_fn(self._event_persist_queue.add_to_queue)( room_id, evs_ctxs, backfilled=backfilled, ) deferreds.append(d) - for room_id in partitioned.keys(): + for room_id in partitioned: self._maybe_start_persisting(room_id) return preserve_context_over_deferred( @@ -226,6 +234,17 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks @log_function def persist_event(self, event, context, backfilled=False): + """ + + Args: + event (EventBase): + context (EventContext): + backfilled (bool): + + Returns: + Deferred: resolves to (int, int): the stream ordering of ``event``, + and the stream ordering of the latest persisted event + """ deferred = self._event_persist_queue.add_to_queue( event.room_id, [(event, context)], backfilled=backfilled, @@ -252,6 +271,16 @@ class EventsStore(SQLBaseStore): @defer.inlineCallbacks def _persist_events(self, events_and_contexts, backfilled=False, delete_existing=False): + """Persist events to db + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + backfilled (bool): + delete_existing (bool): + + Returns: + Deferred: resolves when the events have been persisted + """ if not events_and_contexts: return @@ -284,71 +313,52 @@ class EventsStore(SQLBaseStore): new_forward_extremeties = {} current_state_for_room = {} if not backfilled: - # Work out the new "current state" for each room. - # We do this by working out what the new extremities are and then - # calculating the state from that. - events_by_room = {} - for event, context in chunk: - events_by_room.setdefault(event.room_id, []).append( - (event, context) - ) - - for room_id, ev_ctx_rm in events_by_room.items(): - # Work out new extremities by recursively adding and removing - # the new events. - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) - new_latest_event_ids = yield self._calculate_new_extremeties( - room_id, [ev for ev, _ in ev_ctx_rm] - ) - - if new_latest_event_ids == set(latest_event_ids): - # No change in extremities, so no change in state - continue + with Measure(self._clock, "_calculate_state_and_extrem"): + # Work out the new "current state" for each room. + # We do this by working out what the new extremities are and then + # calculating the state from that. + events_by_room = {} + for event, context in chunk: + events_by_room.setdefault(event.room_id, []).append( + (event, context) + ) - new_forward_extremeties[room_id] = new_latest_event_ids - - # Now we need to work out the different state sets for - # each state extremities - state_sets = [] - missing_event_ids = [] - was_updated = False - for event_id in new_latest_event_ids: - # First search in the list of new events we're adding, - # and then use the current state from that - for ev, ctx in ev_ctx_rm: - if event_id == ev.event_id: - if ctx.current_state_ids is None: - raise Exception("Unknown current state") - state_sets.append(ctx.current_state_ids) - if ctx.delta_ids or hasattr(ev, "state_key"): - was_updated = True - break - else: - # If we couldn't find it, then we'll need to pull - # the state from the database - was_updated = True - missing_event_ids.append(event_id) - - if missing_event_ids: - # Now pull out the state for any missing events from DB - event_to_groups = yield self._get_state_group_for_events( - missing_event_ids, + for room_id, ev_ctx_rm in events_by_room.iteritems(): + # Work out new extremities by recursively adding and removing + # the new events. + latest_event_ids = yield self.get_latest_event_ids_in_room( + room_id + ) + new_latest_event_ids = yield self._calculate_new_extremeties( + room_id, ev_ctx_rm, latest_event_ids ) - groups = set(event_to_groups.values()) - group_to_state = yield self._get_state_for_groups(groups) + if new_latest_event_ids == set(latest_event_ids): + # No change in extremities, so no change in state + continue - state_sets.extend(group_to_state.values()) + new_forward_extremeties[room_id] = new_latest_event_ids - if not new_latest_event_ids or was_updated: - current_state_for_room[room_id] = yield resolve_events( - state_sets, - state_map_factory=lambda ev_ids: self.get_events( - ev_ids, get_prev_content=False, check_redacted=False, - ), + len_1 = ( + len(latest_event_ids) == 1 + and len(new_latest_event_ids) == 1 ) + if len_1: + all_single_prev_not_state = all( + len(event.prev_events) == 1 + and not event.is_state() + for event, ctx in ev_ctx_rm + ) + # Don't bother calculating state if they're just + # a long chain of single ancestor non-state events. + if all_single_prev_not_state: + continue + + state = yield self._calculate_state_delta( + room_id, ev_ctx_rm, new_latest_event_ids + ) + if state: + current_state_for_room[room_id] = state yield self.runInteraction( "persist_events", @@ -362,27 +372,24 @@ class EventsStore(SQLBaseStore): persist_event_counter.inc_by(len(chunk)) @defer.inlineCallbacks - def _calculate_new_extremeties(self, room_id, events): + def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids): """Calculates the new forward extremeties for a room given events to persist. Assumes that we are only persisting events for one room at a time. """ - latest_event_ids = yield self.get_latest_event_ids_in_room( - room_id - ) new_latest_event_ids = set(latest_event_ids) # First, add all the new events to the list new_latest_event_ids.update( - event.event_id for event in events - if not event.internal_metadata.is_outlier() + event.event_id for event, ctx in event_contexts + if not event.internal_metadata.is_outlier() and not ctx.rejected ) # Now remove all events that are referenced by the to-be-added events new_latest_event_ids.difference_update( e_id - for event in events + for event, ctx in event_contexts for e_id, _ in event.prev_events - if not event.internal_metadata.is_outlier() + if not event.internal_metadata.is_outlier() and not ctx.rejected ) # And finally remove any events that are referenced by previously added @@ -406,6 +413,125 @@ class EventsStore(SQLBaseStore): defer.returnValue(new_latest_event_ids) @defer.inlineCallbacks + def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids): + """Calculate the new state deltas for a room. + + Assumes that we are only persisting events for one room at a time. + + Returns: + 2-tuple (to_delete, to_insert) where both are state dicts, i.e. + (type, state_key) -> event_id. `to_delete` are the entries to + first be deleted from current_state_events, `to_insert` are entries + to insert. + May return None if there are no changes to be applied. + """ + # Now we need to work out the different state sets for + # each state extremities + state_sets = [] + state_groups = set() + missing_event_ids = [] + was_updated = False + for event_id in new_latest_event_ids: + # First search in the list of new events we're adding, + # and then use the current state from that + for ev, ctx in events_context: + if event_id == ev.event_id: + if ctx.current_state_ids is None: + raise Exception("Unknown current state") + + # If we've already seen the state group don't bother adding + # it to the state sets again + if ctx.state_group not in state_groups: + state_sets.append(ctx.current_state_ids) + if ctx.delta_ids or hasattr(ev, "state_key"): + was_updated = True + if ctx.state_group: + # Add this as a seen state group (if it has a state + # group) + state_groups.add(ctx.state_group) + break + else: + # If we couldn't find it, then we'll need to pull + # the state from the database + was_updated = True + missing_event_ids.append(event_id) + + if missing_event_ids: + # Now pull out the state for any missing events from DB + event_to_groups = yield self._get_state_group_for_events( + missing_event_ids, + ) + + groups = set(event_to_groups.itervalues()) - state_groups + + if groups: + group_to_state = yield self._get_state_for_groups(groups) + state_sets.extend(group_to_state.itervalues()) + + if not new_latest_event_ids: + current_state = {} + elif was_updated: + if len(state_sets) == 1: + # If there is only one state set, then we know what the current + # state is. + current_state = state_sets[0] + else: + # We work out the current state by passing the state sets to the + # state resolution algorithm. It may ask for some events, including + # the events we have yet to persist, so we need a slightly more + # complicated event lookup function than simply looking the events + # up in the db. + events_map = {ev.event_id: ev for ev, _ in events_context} + + @defer.inlineCallbacks + def get_events(ev_ids): + # We get the events by first looking at the list of events we + # are trying to persist, and then fetching the rest from the DB. + db = [] + to_return = {} + for ev_id in ev_ids: + ev = events_map.get(ev_id, None) + if ev: + to_return[ev_id] = ev + else: + db.append(ev_id) + + if db: + evs = yield self.get_events( + ev_ids, get_prev_content=False, check_redacted=False, + ) + to_return.update(evs) + defer.returnValue(to_return) + + current_state = yield resolve_events( + state_sets, + state_map_factory=get_events, + ) + else: + return + + existing_state = yield self.get_current_state_ids(room_id) + + existing_events = set(existing_state.itervalues()) + new_events = set(ev_id for ev_id in current_state.itervalues()) + changed_events = existing_events ^ new_events + + if not changed_events: + return + + to_delete = { + key: ev_id for key, ev_id in existing_state.iteritems() + if ev_id in changed_events + } + events_to_insert = (new_events - existing_events) + to_insert = { + key: ev_id for key, ev_id in current_state.iteritems() + if ev_id in events_to_insert + } + + defer.returnValue((to_delete, to_insert)) + + @defer.inlineCallbacks def get_event(self, event_id, check_redacted=True, get_prev_content=False, allow_rejected=False, allow_none=False): @@ -471,44 +597,143 @@ class EventsStore(SQLBaseStore): and the rejections table. Things reading from those table will need to check whether the event was rejected. - If delete_existing is True then existing events will be purged from the - database before insertion. This is useful when retrying due to IntegrityError. + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): + events to persist + backfilled (bool): True if the events were backfilled + delete_existing (bool): True to purge existing table rows for the + events from the database. This is useful when retrying due to + IntegrityError. + current_state_for_room (dict[str, (list[str], list[str])]): + The current-state delta for each room. For each room, a tuple + (to_delete, to_insert), being a list of event ids to be removed + from the current state, and a list of event ids to be added to + the current state. + new_forward_extremeties (dict[str, list[str]]): + The new forward extremities for each room. For each room, a + list of the event ids which are the forward extremities. + """ + self._update_current_state_txn(txn, current_state_for_room) + max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering - for room_id, current_state in current_state_for_room.iteritems(): - txn.call_after(self._get_current_state_for_key.invalidate_all) - txn.call_after(self.get_rooms_for_user.invalidate_all) - txn.call_after(self.get_users_in_room.invalidate, (room_id,)) - - # Add an entry to the current_state_resets table to record the point - # where we clobbered the current state - self._simple_insert_txn( - txn, - table="current_state_resets", - values={"event_stream_ordering": max_stream_order} - ) + self._update_forward_extremities_txn( + txn, + new_forward_extremities=new_forward_extremeties, + max_stream_order=max_stream_order, + ) - self._simple_delete_txn( - txn, - table="current_state_events", - keyvalues={"room_id": room_id}, - ) + # Ensure that we don't have the same event twice. + events_and_contexts = self._filter_events_and_contexts_for_duplicates( + events_and_contexts, + ) - self._simple_insert_many_txn( + self._update_room_depths_txn( + txn, + events_and_contexts=events_and_contexts, + backfilled=backfilled, + ) + + # _update_outliers_txn filters out any events which have already been + # persisted, and returns the filtered list. + events_and_contexts = self._update_outliers_txn( + txn, + events_and_contexts=events_and_contexts, + ) + + # From this point onwards the events are only events that we haven't + # seen before. + + if delete_existing: + # For paranoia reasons, we go and delete all the existing entries + # for these events so we can reinsert them. + # This gets around any problems with some tables already having + # entries. + self._delete_existing_rows_txn( txn, - table="current_state_events", - values=[ - { - "event_id": ev_id, - "room_id": room_id, - "type": key[0], - "state_key": key[1], - } - for key, ev_id in current_state.iteritems() - ], + events_and_contexts=events_and_contexts, ) - for room_id, new_extrem in new_forward_extremeties.items(): + self._store_event_txn( + txn, + events_and_contexts=events_and_contexts, + ) + + # Insert into the state_groups, state_groups_state, and + # event_to_state_groups tables. + self._store_mult_state_groups_txn(txn, events_and_contexts) + + # _store_rejected_events_txn filters out any events which were + # rejected, and returns the filtered list. + events_and_contexts = self._store_rejected_events_txn( + txn, + events_and_contexts=events_and_contexts, + ) + + # From this point onwards the events are only ones that weren't + # rejected. + + self._update_metadata_tables_txn( + txn, + events_and_contexts=events_and_contexts, + backfilled=backfilled, + ) + + def _update_current_state_txn(self, txn, state_delta_by_room): + for room_id, current_state_tuple in state_delta_by_room.iteritems(): + to_delete, to_insert = current_state_tuple + txn.executemany( + "DELETE FROM current_state_events WHERE event_id = ?", + [(ev_id,) for ev_id in to_delete.itervalues()], + ) + + self._simple_insert_many_txn( + txn, + table="current_state_events", + values=[ + { + "event_id": ev_id, + "room_id": room_id, + "type": key[0], + "state_key": key[1], + } + for key, ev_id in to_insert.iteritems() + ], + ) + + # Invalidate the various caches + + # Figure out the changes of membership to invalidate the + # `get_rooms_for_user` cache. + # We find out which membership events we may have deleted + # and which we have added, then we invlidate the caches for all + # those users. + members_changed = set( + state_key for ev_type, state_key in to_delete.iterkeys() + if ev_type == EventTypes.Member + ) + members_changed.update( + state_key for ev_type, state_key in to_insert.iterkeys() + if ev_type == EventTypes.Member + ) + + for member in members_changed: + self._invalidate_cache_and_stream( + txn, self.get_rooms_for_user, (member,) + ) + + self._invalidate_cache_and_stream( + txn, self.get_users_in_room, (room_id,) + ) + + self._invalidate_cache_and_stream( + txn, self.get_current_state_ids, (room_id,) + ) + + def _update_forward_extremities_txn(self, txn, new_forward_extremities, + max_stream_order): + for room_id, new_extrem in new_forward_extremities.iteritems(): self._simple_delete_txn( txn, table="event_forward_extremities", @@ -526,7 +751,7 @@ class EventsStore(SQLBaseStore): "event_id": ev_id, "room_id": room_id, } - for room_id, new_extrem in new_forward_extremeties.items() + for room_id, new_extrem in new_forward_extremities.iteritems() for ev_id in new_extrem ], ) @@ -543,13 +768,22 @@ class EventsStore(SQLBaseStore): "event_id": event_id, "stream_ordering": max_stream_order, } - for room_id, new_extrem in new_forward_extremeties.items() + for room_id, new_extrem in new_forward_extremities.iteritems() for event_id in new_extrem ] ) - # Ensure that we don't have the same event twice. - # Pick the earliest non-outlier if there is one, else the earliest one. + @classmethod + def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts): + """Ensure that we don't have the same event twice. + + Pick the earliest non-outlier if there is one, else the earliest one. + + Args: + events_and_contexts (list[(EventBase, EventContext)]): + Returns: + list[(EventBase, EventContext)]: filtered list + """ new_events_and_contexts = OrderedDict() for event, context in events_and_contexts: prev_event_context = new_events_and_contexts.get(event.event_id) @@ -562,9 +796,17 @@ class EventsStore(SQLBaseStore): new_events_and_contexts[event.event_id] = (event, context) else: new_events_and_contexts[event.event_id] = (event, context) + return new_events_and_contexts.values() - events_and_contexts = new_events_and_contexts.values() + def _update_room_depths_txn(self, txn, events_and_contexts, backfilled): + """Update min_depth for each room + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + backfilled (bool): True if the events were backfilled + """ depth_updates = {} for event, context in events_and_contexts: # Remove the any existing cache entries for the event_ids @@ -580,9 +822,24 @@ class EventsStore(SQLBaseStore): event.depth, depth_updates.get(event.room_id, event.depth) ) - for room_id, depth in depth_updates.items(): + for room_id, depth in depth_updates.iteritems(): self._update_min_depth_for_room_txn(txn, room_id, depth) + def _update_outliers_txn(self, txn, events_and_contexts): + """Update any outliers with new event info. + + This turns outliers into ex-outliers (unless the new event was + rejected). + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without events which + are already in the events table. + """ txn.execute( "SELECT event_id, outlier FROM events WHERE event_id in (%s)" % ( ",".join(["?"] * len(events_and_contexts)), @@ -592,24 +849,21 @@ class EventsStore(SQLBaseStore): have_persisted = { event_id: outlier - for event_id, outlier in txn.fetchall() + for event_id, outlier in txn } to_remove = set() for event, context in events_and_contexts: - if context.rejected: - # If the event is rejected then we don't care if the event - # was an outlier or not. - if event.event_id in have_persisted: - # If we have already seen the event then ignore it. - to_remove.add(event) - continue - if event.event_id not in have_persisted: continue to_remove.add(event) + if context.rejected: + # If the event is rejected then we don't care if the event + # was an outlier or not. + continue + outlier_persisted = have_persisted[event.event_id] if not event.internal_metadata.is_outlier() and outlier_persisted: # We received a copy of an event that we had already stored as @@ -664,37 +918,19 @@ class EventsStore(SQLBaseStore): # event isn't an outlier any more. self._update_backward_extremeties(txn, [event]) - events_and_contexts = [ + return [ ec for ec in events_and_contexts if ec[0] not in to_remove ] + @classmethod + def _delete_existing_rows_txn(cls, txn, events_and_contexts): if not events_and_contexts: - # Make sure we don't pass an empty list to functions that expect to - # be storing at least one element. + # nothing to do here return - # From this point onwards the events are only events that we haven't - # seen before. + logger.info("Deleting existing") - def event_dict(event): - return { - k: v - for k, v in event.get_dict().items() - if k not in [ - "redacted", - "redacted_because", - ] - } - - if delete_existing: - # For paranoia reasons, we go and delete all the existing entries - # for these events so we can reinsert them. - # This gets around any problems with some tables already having - # entries. - - logger.info("Deleting existing") - - for table in ( + for table in ( "events", "event_auth", "event_json", @@ -717,11 +953,30 @@ class EventsStore(SQLBaseStore): "redactions", "room_memberships", "topics" - ): - txn.executemany( - "DELETE FROM %s WHERE event_id = ?" % (table,), - [(ev.event_id,) for ev, _ in events_and_contexts] - ) + ): + txn.executemany( + "DELETE FROM %s WHERE event_id = ?" % (table,), + [(ev.event_id,) for ev, _ in events_and_contexts] + ) + + def _store_event_txn(self, txn, events_and_contexts): + """Insert new events into the event and event_json tables + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + """ + + if not events_and_contexts: + # nothing to do here + return + + def event_dict(event): + d = event.get_dict() + d.pop("redacted", None) + d.pop("redacted_because", None) + return d self._simple_insert_many_txn( txn, @@ -765,6 +1020,19 @@ class EventsStore(SQLBaseStore): ], ) + def _store_rejected_events_txn(self, txn, events_and_contexts): + """Add rows to the 'rejections' table for received events which were + rejected + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + + Returns: + list[(EventBase, EventContext)] new list, without the rejected + events. + """ # Remove the rejected events from the list now that we've added them # to the events table and the events_json table. to_remove = set() @@ -776,17 +1044,24 @@ class EventsStore(SQLBaseStore): ) to_remove.add(event) - events_and_contexts = [ + return [ ec for ec in events_and_contexts if ec[0] not in to_remove ] + def _update_metadata_tables_txn(self, txn, events_and_contexts, backfilled): + """Update all the miscellaneous tables for new events + + Args: + txn (twisted.enterprise.adbapi.Connection): db connection + events_and_contexts (list[(EventBase, EventContext)]): events + we are persisting + backfilled (bool): True if the events were backfilled + """ + if not events_and_contexts: - # Make sure we don't pass an empty list to functions that expect to - # be storing at least one element. + # nothing to do here return - # From this point onwards the events are only ones that weren't rejected. - for event, context in events_and_contexts: # Insert all the push actions into the event_push_actions table. if context.push_actions: @@ -815,10 +1090,6 @@ class EventsStore(SQLBaseStore): ], ) - # Insert into the state_groups, state_groups_state, and - # event_to_state_groups tables. - self._store_mult_state_groups_txn(txn, events_and_contexts) - # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. self._handle_mult_prev_events( @@ -905,13 +1176,6 @@ class EventsStore(SQLBaseStore): # Prefill the event cache self._add_to_cache(txn, events_and_contexts) - if backfilled: - # Backfilled events come before the current state so we don't need - # to update the current state table - return - - return - def _add_to_cache(self, txn, events_and_contexts): to_prefill = [] @@ -1507,6 +1771,7 @@ class EventsStore(SQLBaseStore): """The current minimum token that backfilled events have reached""" return -self._backfill_id_gen.get_current_token() + @cached(num_args=5, max_entries=10) def get_all_new_events(self, last_backfill_id, last_forward_id, current_backfill_id, current_forward_id, limit): """Get all the new events that have arrived at the server either as @@ -1519,14 +1784,13 @@ class EventsStore(SQLBaseStore): def get_all_new_events_txn(txn): sql = ( - "SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group" - " FROM events as e" - " JOIN event_json as ej" - " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" - " LEFT JOIN event_to_state_groups as eg" - " ON e.event_id = eg.event_id" - " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?" - " ORDER BY e.stream_ordering ASC" + "SELECT e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? < stream_ordering AND stream_ordering <= ?" + " ORDER BY stream_ordering ASC" " LIMIT ?" ) if have_forward_events: @@ -1539,15 +1803,6 @@ class EventsStore(SQLBaseStore): upper_bound = current_forward_id sql = ( - "SELECT event_stream_ordering FROM current_state_resets" - " WHERE ? < event_stream_ordering" - " AND event_stream_ordering <= ?" - " ORDER BY event_stream_ordering ASC" - ) - txn.execute(sql, (last_forward_id, upper_bound)) - state_resets = txn.fetchall() - - sql = ( "SELECT event_stream_ordering, event_id, state_group" " FROM ex_outlier_stream" " WHERE ? > event_stream_ordering" @@ -1558,19 +1813,16 @@ class EventsStore(SQLBaseStore): forward_ex_outliers = txn.fetchall() else: new_forward_events = [] - state_resets = [] forward_ex_outliers = [] sql = ( - "SELECT -e.stream_ordering, ej.internal_metadata, ej.json," - " eg.state_group" - " FROM events as e" - " JOIN event_json as ej" - " ON e.event_id = ej.event_id AND e.room_id = ej.room_id" - " LEFT JOIN event_to_state_groups as eg" - " ON e.event_id = eg.event_id" - " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?" - " ORDER BY e.stream_ordering DESC" + "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type," + " state_key, redacts" + " FROM events AS e" + " LEFT JOIN redactions USING (event_id)" + " LEFT JOIN state_events USING (event_id)" + " WHERE ? > stream_ordering AND stream_ordering >= ?" + " ORDER BY stream_ordering DESC" " LIMIT ?" ) if have_backfill_events: @@ -1598,7 +1850,6 @@ class EventsStore(SQLBaseStore): return AllNewEventsResult( new_forward_events, new_backfill_events, forward_ex_outliers, backward_ex_outliers, - state_resets, ) return self.runInteraction("get_all_new_events", get_all_new_events_txn) @@ -1758,7 +2009,7 @@ class EventsStore(SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in curr_state.items() + for key, state_id in curr_state.iteritems() ], ) @@ -1824,5 +2075,4 @@ class EventsStore(SQLBaseStore): AllNewEventsResult = namedtuple("AllNewEventsResult", [ "new_forward_events", "new_backfill_events", "forward_ex_outliers", "backward_ex_outliers", - "state_resets" ]) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 86b37b9ddd..3b5e0a4fb9 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -101,9 +101,10 @@ class KeyStore(SQLBaseStore): key_ids Args: server_name (str): The name of the server. - key_ids (list of str): List of key_ids to try and look up. + key_ids (iterable[str]): key_ids to try and look up. Returns: - (list of VerifyKey): The verification keys. + Deferred: resolves to dict[str, VerifyKey]: map from + key_id to verification key. """ keys = {} for key_id in key_ids: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index b357f22be7..6e623843d5 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 40 +SCHEMA_VERSION = 41 dir_path = os.path.abspath(os.path.dirname(__file__)) @@ -356,7 +356,7 @@ def _get_or_create_schema_state(txn, database_engine): ), (current_version,) ) - applied_deltas = [d for d, in txn.fetchall()] + applied_deltas = [d for d, in txn] return current_version, applied_deltas, upgraded return None diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py index 7460f98a1f..9e9d3c2591 100644 --- a/synapse/storage/presence.py +++ b/synapse/storage/presence.py @@ -15,7 +15,7 @@ from ._base import SQLBaseStore from synapse.api.constants import PresenceState -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList from collections import namedtuple from twisted.internet import defer @@ -85,6 +85,9 @@ class PresenceStore(SQLBaseStore): self.presence_stream_cache.entity_has_changed, state.user_id, stream_id, ) + txn.call_after( + self._get_presence_for_user.invalidate, (state.user_id,) + ) # Actually insert new rows self._simple_insert_many_txn( @@ -143,7 +146,12 @@ class PresenceStore(SQLBaseStore): "get_all_presence_updates", get_all_presence_updates_txn ) - @defer.inlineCallbacks + @cached() + def _get_presence_for_user(self, user_id): + raise NotImplementedError() + + @cachedList(cached_method_name="_get_presence_for_user", list_name="user_ids", + num_args=1, inlineCallbacks=True) def get_presence_for_users(self, user_ids): rows = yield self._simple_select_many_batch( table="presence_stream", @@ -165,7 +173,7 @@ class PresenceStore(SQLBaseStore): for row in rows: row["currently_active"] = bool(row["currently_active"]) - defer.returnValue([UserPresenceState(**row) for row in rows]) + defer.returnValue({row["user_id"]: UserPresenceState(**row) for row in rows}) def get_current_presence_token(self): return self._presence_id_gen.get_current_token() diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py index f72d15f5ed..6b0f8c2787 100644 --- a/synapse/storage/receipts.py +++ b/synapse/storage/receipts.py @@ -313,10 +313,9 @@ class ReceiptsStore(SQLBaseStore): ) txn.execute(sql, (room_id, receipt_type, user_id)) - results = txn.fetchall() - if results and topological_ordering: - for to, so, _ in results: + if topological_ordering: + for to, so, _ in txn: if int(to) > topological_ordering: return False elif int(to) == topological_ordering and int(so) >= stream_ordering: @@ -351,6 +350,7 @@ class ReceiptsStore(SQLBaseStore): room_id=room_id, user_id=user_id, topological_ordering=topological_ordering, + stream_ordering=stream_ordering, ) return True diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 26be6060c3..ec2c52ab93 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -209,7 +209,7 @@ class RegistrationStore(background_updates.BackgroundUpdateStore): " WHERE lower(name) = lower(?)" ) txn.execute(sql, (user_id,)) - return dict(txn.fetchall()) + return dict(txn) return self.runInteraction("get_users_by_id_case_insensitive", f) diff --git a/synapse/storage/room.py b/synapse/storage/room.py index 8a2fe2fdf5..e4c56cc175 100644 --- a/synapse/storage/room.py +++ b/synapse/storage/room.py @@ -396,7 +396,7 @@ class RoomStore(SQLBaseStore): sql % ("AND appservice_id IS NULL",), (stream_id,) ) - return dict(txn.fetchall()) + return dict(txn) else: # We want to get from all lists, so we need to aggregate the results @@ -422,7 +422,7 @@ class RoomStore(SQLBaseStore): results = {} # A room is visible if its visible on any list. - for room_id, visibility in txn.fetchall(): + for room_id, visibility in txn: results[room_id] = bool(visibility) or results.get(room_id, False) return results diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 768e0a4451..367dbbbcf6 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -66,8 +66,6 @@ class RoomMemberStore(SQLBaseStore): ) for event in events: - txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,)) - txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) txn.call_after( self._membership_stream_cache.entity_has_changed, event.state_key, event.internal_metadata.stream_ordering @@ -131,17 +129,30 @@ class RoomMemberStore(SQLBaseStore): with self._stream_id_gen.get_next() as stream_ordering: yield self.runInteraction("locally_reject_invite", f, stream_ordering) - @cached(max_entries=5000) + @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True) + def get_hosts_in_room(self, room_id, cache_context): + """Returns the set of all hosts currently in the room + """ + user_ids = yield self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate, + ) + hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids) + defer.returnValue(hosts) + + @cached(max_entries=500000, iterable=True) def get_users_in_room(self, room_id): def f(txn): - - rows = self._get_members_rows_txn( - txn, - room_id=room_id, - membership=Membership.JOIN, + sql = ( + "SELECT m.user_id FROM room_memberships as m" + " INNER JOIN current_state_events as c" + " ON m.event_id = c.event_id " + " AND m.room_id = c.room_id " + " AND m.user_id = c.state_key" + " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?" ) - return [r["user_id"] for r in rows] + txn.execute(sql, (room_id, Membership.JOIN,)) + return [r[0] for r in txn] return self.runInteraction("get_users_in_room", f) @cached() @@ -220,7 +231,7 @@ class RoomMemberStore(SQLBaseStore): " ON e.event_id = c.event_id" " AND m.room_id = c.room_id" " AND m.user_id = c.state_key" - " WHERE %s" + " WHERE c.type = 'm.room.member' AND %s" ) % (where_clause,) txn.execute(sql, args) @@ -248,39 +259,31 @@ class RoomMemberStore(SQLBaseStore): return results - def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None): - where_clause = "c.room_id = ?" - where_values = [room_id] - - if membership: - where_clause += " AND m.membership = ?" - where_values.append(membership) - - if user_id: - where_clause += " AND m.user_id = ?" - where_values.append(user_id) - - sql = ( - "SELECT m.* FROM room_memberships as m" - " INNER JOIN current_state_events as c" - " ON m.event_id = c.event_id " - " AND m.room_id = c.room_id " - " AND m.user_id = c.state_key" - " WHERE %(where)s" - ) % { - "where": where_clause, - } - - txn.execute(sql, where_values) - rows = self.cursor_to_dict(txn) - - return rows - - @cached(max_entries=5000) + @cachedInlineCallbacks(max_entries=500000, iterable=True) def get_rooms_for_user(self, user_id): - return self.get_rooms_for_user_where_membership_is( + """Returns a set of room_ids the user is currently joined to + """ + rooms = yield self.get_rooms_for_user_where_membership_is( user_id, membership_list=[Membership.JOIN], ) + defer.returnValue(frozenset(r.room_id for r in rooms)) + + @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True) + def get_users_who_share_room_with_user(self, user_id, cache_context): + """Returns the set of users who share a room with `user_id` + """ + room_ids = yield self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate, + ) + + user_who_share_room = set() + for room_id in room_ids: + user_ids = yield self.get_users_in_room( + room_id, on_invalidate=cache_context.invalidate, + ) + user_who_share_room.update(user_ids) + + defer.returnValue(user_who_share_room) def forget(self, user_id, room_id): """Indicate that user_id wishes to discard history for room_id.""" diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/schema/delta/40/current_state_idx.sql new file mode 100644 index 0000000000..7ffa189f39 --- /dev/null +++ b/synapse/storage/schema/delta/40/current_state_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (update_name, progress_json) VALUES + ('current_state_members_idx', '{}'); diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/schema/delta/40/device_list_streams.sql new file mode 100644 index 0000000000..54841b3843 --- /dev/null +++ b/synapse/storage/schema/delta/40/device_list_streams.sql @@ -0,0 +1,59 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Cache of remote devices. +CREATE TABLE device_lists_remote_cache ( + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + content TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id); + + +-- The last update we got for a user. Empty if we're not receiving updates for +-- that user. +CREATE TABLE device_lists_remote_extremeties ( + user_id TEXT NOT NULL, + stream_id TEXT NOT NULL +); + +CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id); + + +-- Stream of device lists updates. Includes both local and remotes +CREATE TABLE device_lists_stream ( + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL +); + +CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id); + + +-- The stream of updates to send to other servers. We keep at least one row +-- per user that was sent so that the prev_id for any new updates can be +-- calculated +CREATE TABLE device_lists_outbound_pokes ( + destination TEXT NOT NULL, + stream_id BIGINT NOT NULL, + user_id TEXT NOT NULL, + device_id TEXT NOT NULL, + sent BOOLEAN NOT NULL, + ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers +); + +CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id); +CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id); diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/schema/delta/40/event_push_summary.sql new file mode 100644 index 0000000000..3918f0b794 --- /dev/null +++ b/synapse/storage/schema/delta/40/event_push_summary.sql @@ -0,0 +1,37 @@ +/* Copyright 2017 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Aggregate of old notification counts that have been deleted out of the +-- main event_push_actions table. This count does not include those that were +-- highlights, as they remain in the event_push_actions table. +CREATE TABLE event_push_summary ( + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + notif_count BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL +); + +CREATE INDEX event_push_summary_user_rm ON event_push_summary(user_id, room_id); + + +-- The stream ordering up to which we have aggregated the event_push_actions +-- table into event_push_summary +CREATE TABLE event_push_summary_stream_ordering ( + Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row. + stream_ordering BIGINT NOT NULL, + CHECK (Lock='X') +); + +INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0); diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/schema/delta/40/pushers.sql new file mode 100644 index 0000000000..054a223f14 --- /dev/null +++ b/synapse/storage/schema/delta/40/pushers.sql @@ -0,0 +1,39 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS pushers2 ( + id BIGINT PRIMARY KEY, + user_name TEXT NOT NULL, + access_token BIGINT DEFAULT NULL, + profile_tag TEXT NOT NULL, + kind TEXT NOT NULL, + app_id TEXT NOT NULL, + app_display_name TEXT NOT NULL, + device_display_name TEXT NOT NULL, + pushkey TEXT NOT NULL, + ts BIGINT NOT NULL, + lang TEXT, + data TEXT, + last_stream_ordering INTEGER, + last_success BIGINT, + failing_since BIGINT, + UNIQUE (app_id, pushkey, user_name) +); + +INSERT INTO pushers2 SELECT * FROM PUSHERS; + +DROP TABLE PUSHERS; + +ALTER TABLE pushers2 RENAME TO pushers; diff --git a/synapse/storage/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/schema/delta/41/device_list_stream_idx.sql new file mode 100644 index 0000000000..b7bee8b692 --- /dev/null +++ b/synapse/storage/schema/delta/41/device_list_stream_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT into background_updates (update_name, progress_json) + VALUES ('device_lists_stream_idx', '{}'); diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/schema/delta/41/device_outbound_index.sql new file mode 100644 index 0000000000..62f0b9892b --- /dev/null +++ b/synapse/storage/schema/delta/41/device_outbound_index.sql @@ -0,0 +1,16 @@ +/* Copyright 2017 Vector Creations Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX device_lists_outbound_pokes_stream ON device_lists_outbound_pokes(stream_id); diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py index e1dca927d7..67d5d9969a 100644 --- a/synapse/storage/signatures.py +++ b/synapse/storage/signatures.py @@ -72,7 +72,7 @@ class SignatureStore(SQLBaseStore): " WHERE event_id = ?" ) txn.execute(query, (event_id, )) - return {k: v for k, v in txn.fetchall()} + return {k: v for k, v in txn} def _store_event_reference_hashes_txn(self, txn, events): """Store a hash for a PDU diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 7d34dd03bf..314216f039 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -14,7 +14,7 @@ # limitations under the License. from ._base import SQLBaseStore -from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks from synapse.util.caches import intern_string from synapse.storage.engines import PostgresEngine @@ -49,6 +49,7 @@ class StateStore(SQLBaseStore): STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication" STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index" + CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx" def __init__(self, hs): super(StateStore, self).__init__(hs) @@ -60,6 +61,25 @@ class StateStore(SQLBaseStore): self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state, ) + self.register_background_index_update( + self.CURRENT_STATE_INDEX_UPDATE_NAME, + index_name="current_state_events_member_index", + table="current_state_events", + columns=["state_key"], + where_clause="type='m.room.member'", + ) + + @cachedInlineCallbacks(max_entries=100000, iterable=True) + def get_current_state_ids(self, room_id): + rows = yield self._simple_select_list( + table="current_state_events", + keyvalues={"room_id": room_id}, + retcols=["event_id", "type", "state_key"], + desc="_calculate_state_delta", + ) + defer.returnValue({ + (r["type"], r["state_key"]): r["event_id"] for r in rows + }) @defer.inlineCallbacks def get_state_groups_ids(self, room_id, event_ids): @@ -70,7 +90,7 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.values()) + groups = set(event_to_groups.itervalues()) group_to_state = yield self._get_state_for_groups(groups) defer.returnValue(group_to_state) @@ -88,17 +108,18 @@ class StateStore(SQLBaseStore): state_event_map = yield self.get_events( [ - ev_id for group_ids in group_to_ids.values() - for ev_id in group_ids.values() + ev_id for group_ids in group_to_ids.itervalues() + for ev_id in group_ids.itervalues() ], get_prev_content=False ) defer.returnValue({ group: [ - state_event_map[v] for v in event_id_map.values() if v in state_event_map + state_event_map[v] for v in event_id_map.itervalues() + if v in state_event_map ] - for group, event_id_map in group_to_ids.items() + for group, event_id_map in group_to_ids.iteritems() }) def _have_persisted_state_group_txn(self, txn, state_group): @@ -116,6 +137,16 @@ class StateStore(SQLBaseStore): continue if context.current_state_ids is None: + # AFAIK, this can never happen + logger.error( + "Non-outlier event %s had current_state_ids==None", + event.event_id) + continue + + # if the event was rejected, just give it the same state as its + # predecessor. + if context.rejected: + state_groups[event.event_id] = context.prev_group continue state_groups[event.event_id] = context.state_group @@ -160,7 +191,7 @@ class StateStore(SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in context.delta_ids.items() + for key, state_id in context.delta_ids.iteritems() ], ) else: @@ -175,7 +206,7 @@ class StateStore(SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in context.current_state_ids.items() + for key, state_id in context.current_state_ids.iteritems() ], ) @@ -187,7 +218,7 @@ class StateStore(SQLBaseStore): "state_group": state_group_id, "event_id": event_id, } - for event_id, state_group_id in state_groups.items() + for event_id, state_group_id in state_groups.iteritems() ], ) @@ -232,58 +263,6 @@ class StateStore(SQLBaseStore): return count - @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key=""): - if event_type and state_key is not None: - result = yield self.get_current_state_for_key( - room_id, event_type, state_key - ) - defer.returnValue(result) - - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? " - ) - - if event_type and state_key is not None: - sql += " AND type = ? AND state_key = ? " - args = (room_id, event_type, state_key) - elif event_type: - sql += " AND type = ?" - args = (room_id, event_type) - else: - args = (room_id, ) - - txn.execute(sql, args) - results = txn.fetchall() - - return [r[0] for r in results] - - event_ids = yield self.runInteraction("get_current_state", f) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @defer.inlineCallbacks - def get_current_state_for_key(self, room_id, event_type, state_key): - event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key) - events = yield self._get_events(event_ids, get_prev_content=False) - defer.returnValue(events) - - @cached(num_args=3) - def _get_current_state_for_key(self, room_id, event_type, state_key): - def f(txn): - sql = ( - "SELECT event_id FROM current_state_events" - " WHERE room_id = ? AND type = ? AND state_key = ?" - ) - - args = (room_id, event_type, state_key) - txn.execute(sql, args) - results = txn.fetchall() - return [r[0] for r in results] - return self.runInteraction("get_current_state_for_key", f) - @cached(num_args=2, max_entries=100000, iterable=True) def _get_state_group_from_group(self, group, types): raise NotImplementedError() @@ -363,10 +342,10 @@ class StateStore(SQLBaseStore): args.extend(where_args) txn.execute(sql % (where_clause,), args) - rows = self.cursor_to_dict(txn) - for row in rows: - key = (row["type"], row["state_key"]) - results[group][key] = row["event_id"] + for row in txn: + typ, state_key, event_id = row + key = (typ, state_key) + results[group][key] = event_id else: if types is not None: where_clause = "AND (%s)" % ( @@ -395,12 +374,11 @@ class StateStore(SQLBaseStore): " WHERE state_group = ? %s" % (where_clause,), args ) - rows = txn.fetchall() - results[group].update({ - (typ, state_key): event_id - for typ, state_key, event_id in rows + results[group].update( + ((typ, state_key), event_id) + for typ, state_key, event_id in txn if (typ, state_key) not in results[group] - }) + ) # If the lengths match then we must have all the types, # so no need to go walk further down the tree. @@ -437,37 +415,49 @@ class StateStore(SQLBaseStore): event_ids, ) - groups = set(event_to_groups.values()) + groups = set(event_to_groups.itervalues()) group_to_state = yield self._get_state_for_groups(groups, types) state_event_map = yield self.get_events( - [ev_id for sd in group_to_state.values() for ev_id in sd.values()], + [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()], get_prev_content=False ) event_to_state = { event_id: { k: state_event_map[v] - for k, v in group_to_state[group].items() + for k, v in group_to_state[group].iteritems() if v in state_event_map } - for event_id, group in event_to_groups.items() + for event_id, group in event_to_groups.iteritems() } defer.returnValue({event: event_to_state[event] for event in event_ids}) @defer.inlineCallbacks - def get_state_ids_for_events(self, event_ids, types): + def get_state_ids_for_events(self, event_ids, types=None): + """ + Get the state dicts corresponding to a list of events + + Args: + event_ids(list(str)): events whose state should be returned + types(list[(str, str)]|None): List of (type, state_key) tuples + which are used to filter the state fetched. May be None, which + matches any key + + Returns: + A deferred dict from event_id -> (type, state_key) -> state_event + """ event_to_groups = yield self._get_state_group_for_events( event_ids, ) - groups = set(event_to_groups.values()) + groups = set(event_to_groups.itervalues()) group_to_state = yield self._get_state_for_groups(groups, types) event_to_state = { event_id: group_to_state[group] - for event_id, group in event_to_groups.items() + for event_id, group in event_to_groups.iteritems() } defer.returnValue({event: event_to_state[event] for event in event_ids}) @@ -579,7 +569,7 @@ class StateStore(SQLBaseStore): got_all = not (missing_types or types is None) return { - k: v for k, v in state_dict_ids.items() + k: v for k, v in state_dict_ids.iteritems() if include(k[0], k[1]) }, missing_types, got_all @@ -638,7 +628,7 @@ class StateStore(SQLBaseStore): # Now we want to update the cache with all the things we fetched # from the database. - for group, group_state_dict in group_to_state_dict.items(): + for group, group_state_dict in group_to_state_dict.iteritems(): if types: # We delibrately put key -> None mappings into the cache to # cache absence of the key, on the assumption that if we've @@ -653,10 +643,10 @@ class StateStore(SQLBaseStore): else: state_dict = results[group] - state_dict.update({ - (intern_string(k[0]), intern_string(k[1])): v - for k, v in group_state_dict.items() - }) + state_dict.update( + ((intern_string(k[0]), intern_string(k[1])), v) + for k, v in group_state_dict.iteritems() + ) self._state_group_cache.update( cache_seq_num, @@ -667,10 +657,10 @@ class StateStore(SQLBaseStore): # Remove all the entries with None values. The None values were just # used for bookkeeping in the cache. - for group, state_dict in results.items(): + for group, state_dict in results.iteritems(): results[group] = { key: event_id - for key, event_id in state_dict.items() + for key, event_id in state_dict.iteritems() if event_id } @@ -759,7 +749,7 @@ class StateStore(SQLBaseStore): # of keys delta_state = { - key: value for key, value in curr_state.items() + key: value for key, value in curr_state.iteritems() if prev_state.get(key, None) != value } @@ -799,7 +789,7 @@ class StateStore(SQLBaseStore): "state_key": key[1], "event_id": state_id, } - for key, state_id in delta_state.items() + for key, state_id in delta_state.iteritems() ], ) diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 2dc24951c4..dddd5fc0e7 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -244,6 +244,20 @@ class StreamStore(SQLBaseStore): defer.returnValue(results) + def get_rooms_that_changed(self, room_ids, from_key): + """Given a list of rooms and a token, return rooms where there may have + been changes. + + Args: + room_ids (list) + from_key (str): The room_key portion of a StreamToken + """ + from_key = RoomStreamToken.parse_stream_token(from_key).stream + return set( + room_id for room_id in room_ids + if self._events_stream_cache.has_entity_changed(room_id, from_key) + ) + @defer.inlineCallbacks def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0, order='DESC'): @@ -815,3 +829,6 @@ class StreamStore(SQLBaseStore): updatevalues={"stream_id": stream_id}, desc="update_federation_out_pos", ) + + def has_room_changed_since(self, room_id, stream_id): + return self._events_stream_cache.has_entity_changed(room_id, stream_id) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 5a2c1aa59b..bff73f3f04 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -95,7 +95,7 @@ class TagsStore(SQLBaseStore): for stream_id, user_id, room_id in tag_ids: txn.execute(sql, (user_id, room_id)) tags = [] - for tag, content in txn.fetchall(): + for tag, content in txn: tags.append(json.dumps(tag) + ":" + content) tag_json = "{" + ",".join(tags) + "}" results.append((stream_id, user_id, room_id, tag_json)) @@ -132,7 +132,7 @@ class TagsStore(SQLBaseStore): " WHERE user_id = ? AND stream_id > ?" ) txn.execute(sql, (user_id, stream_id)) - room_ids = [row[0] for row in txn.fetchall()] + room_ids = [row[0] for row in txn] return room_ids changed = self._account_data_stream_cache.has_entity_changed( diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 46cf93ff87..95031dc9ec 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -30,6 +30,17 @@ class IdGenerator(object): def _load_current_id(db_conn, table, column, step=1): + """ + + Args: + db_conn (object): + table (str): + column (str): + step (int): + + Returns: + int + """ cur = db_conn.cursor() if step == 1: cur.execute("SELECT MAX(%s) FROM %s" % (column, table,)) @@ -131,6 +142,9 @@ class StreamIdGenerator(object): def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or equal to it have been successfully persisted. + + Returns: + int """ with self._lock: if self._unfinished_ids: |