diff options
Diffstat (limited to 'synapse/storage/__init__.py')
-rw-r--r-- | synapse/storage/__init__.py | 644 |
1 files changed, 147 insertions, 497 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 4b16f445d6..0cc14fb692 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -14,13 +14,12 @@ # limitations under the License. from twisted.internet import defer - -from synapse.util.logutils import log_function -from synapse.api.constants import EventTypes - -from .appservice import ApplicationServiceStore +from .appservice import ( + ApplicationServiceStore, ApplicationServiceTransactionStore +) +from ._base import Cache from .directory import DirectoryStore -from .feedback import FeedbackStore +from .events import EventsStore from .presence import PresenceStore from .profile import ProfileStore from .registration import RegistrationStore @@ -39,11 +38,6 @@ from .state import StateStore from .signatures import SignatureStore from .filtering import FilteringStore -from syutil.base64util import decode_base64 -from syutil.jsonutil import encode_canonical_json - -from synapse.crypto.event_signing import compute_event_reference_hash - import fnmatch import imp @@ -57,20 +51,18 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 14 +SCHEMA_VERSION = 17 dir_path = os.path.abspath(os.path.dirname(__file__)) - -class _RollbackButIsFineException(Exception): - """ This exception is used to rollback a transaction without implying - something went wrong. - """ - pass +# Number of msec of granularity to store the user IP 'last seen' time. Smaller +# times give more inserts into the database even for readonly API hits +# 120 seconds == 2 minutes +LAST_SEEN_GRANULARITY = 120*1000 class DataStore(RoomMemberStore, RoomStore, - RegistrationStore, StreamStore, ProfileStore, FeedbackStore, + RegistrationStore, StreamStore, ProfileStore, PresenceStore, TransactionStore, DirectoryStore, KeyStore, StateStore, SignatureStore, ApplicationServiceStore, @@ -79,7 +71,9 @@ class DataStore(RoomMemberStore, RoomStore, RejectionsStore, FilteringStore, PusherStore, - PushRuleStore + PushRuleStore, + ApplicationServiceTransactionStore, + EventsStore, ): def __init__(self, hs): @@ -89,474 +83,53 @@ class DataStore(RoomMemberStore, RoomStore, self.min_token_deferred = self._get_min_token() self.min_token = None - @defer.inlineCallbacks - @log_function - def persist_event(self, event, context, backfilled=False, - is_new_state=True, current_state=None): - stream_ordering = None - if backfilled: - if not self.min_token_deferred.called: - yield self.min_token_deferred - self.min_token -= 1 - stream_ordering = self.min_token - - try: - yield self.runInteraction( - "persist_event", - self._persist_event_txn, - event=event, - context=context, - backfilled=backfilled, - stream_ordering=stream_ordering, - is_new_state=is_new_state, - current_state=current_state, - ) - except _RollbackButIsFineException: - pass - - @defer.inlineCallbacks - def get_event(self, event_id, check_redacted=True, - get_prev_content=False, allow_rejected=False, - allow_none=False): - """Get an event from the database by event_id. - - Args: - event_id (str): The event_id of the event to fetch - check_redacted (bool): If True, check if event has been redacted - and redact it. - get_prev_content (bool): If True and event is a state event, - include the previous states content in the unsigned field. - allow_rejected (bool): If True return rejected events. - allow_none (bool): If True, return None if no event found, if - False throw an exception. - - Returns: - Deferred : A FrozenEvent. - """ - event = yield self.runInteraction( - "get_event", self._get_event_txn, - event_id, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - allow_rejected=allow_rejected, - ) - - if not event and not allow_none: - raise RuntimeError("Could not find event %s" % (event_id,)) - - defer.returnValue(event) - - @log_function - def _persist_event_txn(self, txn, event, context, backfilled, - stream_ordering=None, is_new_state=True, - current_state=None): - - # Remove the any existing cache entries for the event_id - self._get_event_cache.pop(event.event_id) - - # We purposefully do this first since if we include a `current_state` - # key, we *want* to update the `current_state_events` table - if current_state: - txn.execute( - "DELETE FROM current_state_events WHERE room_id = ?", - (event.room_id,) - ) - - for s in current_state: - self._simple_insert_txn( - txn, - "current_state_events", - { - "event_id": s.event_id, - "room_id": s.room_id, - "type": s.type, - "state_key": s.state_key, - }, - or_replace=True, - ) - - if event.is_state() and is_new_state: - if not backfilled and not context.rejected: - self._simple_insert_txn( - txn, - table="state_forward_extremities", - values={ - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - or_replace=True, - ) - - for prev_state_id, _ in event.prev_state: - self._simple_delete_txn( - txn, - table="state_forward_extremities", - keyvalues={ - "event_id": prev_state_id, - } - ) - - outlier = event.internal_metadata.is_outlier() - - if not outlier: - self._store_state_groups_txn(txn, event, context) - - self._update_min_depth_for_room_txn( - txn, - event.room_id, - event.depth - ) - - self._handle_prev_events( - txn, - outlier=outlier, - event_id=event.event_id, - prev_events=event.prev_events, - room_id=event.room_id, - ) - - have_persisted = self._simple_select_one_onecol_txn( - txn, - table="event_json", - keyvalues={"event_id": event.event_id}, - retcol="event_id", - allow_none=True, - ) - - metadata_json = encode_canonical_json( - event.internal_metadata.get_dict() - ) - - # If we have already persisted this event, we don't need to do any - # more processing. - # The processing above must be done on every call to persist event, - # since they might not have happened on previous calls. For example, - # if we are persisting an event that we had persisted as an outlier, - # but is no longer one. - if have_persisted: - if not outlier: - sql = ( - "UPDATE event_json SET internal_metadata = ?" - " WHERE event_id = ?" - ) - txn.execute( - sql, - (metadata_json.decode("UTF-8"), event.event_id,) - ) - - sql = ( - "UPDATE events SET outlier = 0" - " WHERE event_id = ?" - ) - txn.execute( - sql, - (event.event_id,) - ) - return - - if event.type == EventTypes.Member: - self._store_room_member_txn(txn, event) - elif event.type == EventTypes.Feedback: - self._store_feedback_txn(txn, event) - elif event.type == EventTypes.Name: - self._store_room_name_txn(txn, event) - elif event.type == EventTypes.Topic: - self._store_room_topic_txn(txn, event) - elif event.type == EventTypes.Redaction: - self._store_redaction(txn, event) - - event_dict = { - k: v - for k, v in event.get_dict().items() - if k not in [ - "redacted", - "redacted_because", - ] - } - - self._simple_insert_txn( - txn, - table="event_json", - values={ - "event_id": event.event_id, - "room_id": event.room_id, - "internal_metadata": metadata_json.decode("UTF-8"), - "json": encode_canonical_json(event_dict).decode("UTF-8"), - }, - or_replace=True, - ) - - content = encode_canonical_json( - event.content - ).decode("UTF-8") - - vals = { - "topological_ordering": event.depth, - "event_id": event.event_id, - "type": event.type, - "room_id": event.room_id, - "content": content, - "processed": True, - "outlier": outlier, - "depth": event.depth, - } - - if stream_ordering is not None: - vals["stream_ordering"] = stream_ordering - - unrec = { - k: v - for k, v in event.get_dict().items() - if k not in vals.keys() and k not in [ - "redacted", - "redacted_because", - "signatures", - "hashes", - "prev_events", - ] - } - - vals["unrecognized_keys"] = encode_canonical_json( - unrec - ).decode("UTF-8") - - try: - self._simple_insert_txn( - txn, - "events", - vals, - or_replace=(not outlier), - or_ignore=bool(outlier), - ) - except: - logger.warn( - "Failed to persist, probably duplicate: %s", - event.event_id, - exc_info=True, - ) - raise _RollbackButIsFineException("_persist_event") - - if context.rejected: - self._store_rejections_txn(txn, event.event_id, context.rejected) - - if event.is_state(): - vals = { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - } - - # TODO: How does this work with backfilling? - if hasattr(event, "replaces_state"): - vals["prev_state"] = event.replaces_state - - self._simple_insert_txn( - txn, - "state_events", - vals, - or_replace=True, - ) - - if is_new_state and not context.rejected: - self._simple_insert_txn( - txn, - "current_state_events", - { - "event_id": event.event_id, - "room_id": event.room_id, - "type": event.type, - "state_key": event.state_key, - }, - or_replace=True, - ) - - for e_id, h in event.prev_state: - self._simple_insert_txn( - txn, - table="event_edges", - values={ - "event_id": event.event_id, - "prev_event_id": e_id, - "room_id": event.room_id, - "is_state": 1, - }, - or_ignore=True, - ) - - for hash_alg, hash_base64 in event.hashes.items(): - hash_bytes = decode_base64(hash_base64) - self._store_event_content_hash_txn( - txn, event.event_id, hash_alg, hash_bytes, - ) - - for prev_event_id, prev_hashes in event.prev_events: - for alg, hash_base64 in prev_hashes.items(): - hash_bytes = decode_base64(hash_base64) - self._store_prev_event_hash_txn( - txn, event.event_id, prev_event_id, alg, hash_bytes - ) - - for auth_id, _ in event.auth_events: - self._simple_insert_txn( - txn, - table="event_auth", - values={ - "event_id": event.event_id, - "room_id": event.room_id, - "auth_id": auth_id, - }, - or_ignore=True, - ) - - (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event) - self._store_event_reference_hash_txn( - txn, event.event_id, ref_alg, ref_hash_bytes - ) - - def _store_redaction(self, txn, event): - # invalidate the cache for the redacted event - self._get_event_cache.pop(event.redacts) - txn.execute( - "INSERT OR IGNORE INTO redactions " - "(event_id, redacts) VALUES (?,?)", - (event.event_id, event.redacts) + self.client_ip_last_seen = Cache( + name="client_ip_last_seen", + keylen=4, ) @defer.inlineCallbacks - def get_current_state(self, room_id, event_type=None, state_key=""): - del_sql = ( - "SELECT event_id FROM redactions WHERE redacts = e.event_id " - "LIMIT 1" - ) - - sql = ( - "SELECT e.*, (%(redacted)s) AS redacted FROM events as e " - "INNER JOIN current_state_events as c ON e.event_id = c.event_id " - "INNER JOIN state_events as s ON e.event_id = s.event_id " - "WHERE c.room_id = ? " - ) % { - "redacted": del_sql, - } - - if event_type and state_key is not None: - sql += " AND s.type = ? AND s.state_key = ? " - args = (room_id, event_type, state_key) - elif event_type: - sql += " AND s.type = ?" - args = (room_id, event_type) - else: - args = (room_id, ) - - results = yield self._execute_and_decode("get_current_state", sql, *args) - - events = yield self._parse_events(results) - defer.returnValue(events) - - @defer.inlineCallbacks - def get_room_name_and_aliases(self, room_id): - del_sql = ( - "SELECT event_id FROM redactions WHERE redacts = e.event_id " - "LIMIT 1" - ) - - sql = ( - "SELECT e.*, (%(redacted)s) AS redacted FROM events as e " - "INNER JOIN current_state_events as c ON e.event_id = c.event_id " - "INNER JOIN state_events as s ON e.event_id = s.event_id " - "WHERE c.room_id = ? " - ) % { - "redacted": del_sql, - } - - sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')" - sql += " OR s.type = 'm.room.aliases')" - args = (room_id,) - - results = yield self._execute_and_decode("get_current_state", sql, *args) - - events = yield self._parse_events(results) - - name = None - aliases = [] - - for e in events: - if e.type == 'm.room.name': - if 'name' in e.content: - name = e.content['name'] - elif e.type == 'm.room.aliases': - if 'aliases' in e.content: - aliases.extend(e.content['aliases']) - - defer.returnValue((name, aliases)) - - @defer.inlineCallbacks - def _get_min_token(self): - row = yield self._execute( - "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events" - ) + def insert_client_ip(self, user, access_token, device_id, ip, user_agent): + now = int(self._clock.time_msec()) + key = (user.to_string(), access_token, device_id, ip) - self.min_token = row[0][0] if row and row[0] and row[0][0] else -1 - self.min_token = min(self.min_token, -1) + try: + last_seen = self.client_ip_last_seen.get(*key) + except KeyError: + last_seen = None - logger.debug("min_token is: %s", self.min_token) + # Rate-limited inserts + if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: + defer.returnValue(None) - defer.returnValue(self.min_token) + self.client_ip_last_seen.prefill(*key + (now,)) - def insert_client_ip(self, user, access_token, device_id, ip, user_agent): - return self._simple_insert( + # It's safe not to lock here: a) no unique constraint, + # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely + yield self._simple_upsert( "user_ips", - { - "user": user.to_string(), + keyvalues={ + "user_id": user.to_string(), "access_token": access_token, - "device_id": device_id, "ip": ip, "user_agent": user_agent, - "last_seen": int(self._clock.time_msec()), - } + }, + values={ + "device_id": device_id, + "last_seen": now, + }, + desc="insert_client_ip", + lock=False, ) def get_user_ip_and_agents(self, user): return self._simple_select_list( table="user_ips", - keyvalues={"user": user.to_string()}, + keyvalues={"user_id": user.to_string()}, retcols=[ "device_id", "access_token", "ip", "user_agent", "last_seen" ], - ) - - def have_events(self, event_ids): - """Given a list of event ids, check if we have already processed them. - - Returns: - dict: Has an entry for each event id we already have seen. Maps to - the rejected reason string if we rejected the event, else maps to - None. - """ - if not event_ids: - return defer.succeed({}) - - def f(txn): - sql = ( - "SELECT e.event_id, reason FROM events as e " - "LEFT JOIN rejections as r ON e.event_id = r.event_id " - "WHERE e.event_id = ?" - ) - - res = {} - for event_id in event_ids: - txn.execute(sql, (event_id,)) - row = txn.fetchone() - if row: - _, rejected = row - res[event_id] = rejected - - return res - - return self.runInteraction( - "have_events", f, + desc="get_user_ip_and_agents", ) @@ -580,21 +153,23 @@ class UpgradeDatabaseException(PrepareDatabaseException): pass -def prepare_database(db_conn): +def prepare_database(db_conn, database_engine): """Prepares a database for usage. Will either create all necessary tables or upgrade from an older schema version. """ try: cur = db_conn.cursor() - version_info = _get_or_create_schema_state(cur) + version_info = _get_or_create_schema_state(cur, database_engine) if version_info: user_version, delta_files, upgraded = version_info - _upgrade_existing_database(cur, user_version, delta_files, upgraded) + _upgrade_existing_database( + cur, user_version, delta_files, upgraded, database_engine + ) else: - _setup_new_database(cur) + _setup_new_database(cur, database_engine) - cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) + # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,)) cur.close() db_conn.commit() @@ -603,7 +178,7 @@ def prepare_database(db_conn): raise -def _setup_new_database(cur): +def _setup_new_database(cur, database_engine): """Sets up the database by finding a base set of "full schemas" and then applying any necessary deltas. @@ -657,31 +232,30 @@ def _setup_new_database(cur): directory_entries = os.listdir(sql_dir) - sql_script = "BEGIN TRANSACTION;\n" for filename in fnmatch.filter(directory_entries, "*.sql"): sql_loc = os.path.join(sql_dir, filename) logger.debug("Applying schema %s", sql_loc) - sql_script += read_schema(sql_loc) - sql_script += "\n" - sql_script += "COMMIT TRANSACTION;" - cur.executescript(sql_script) + executescript(cur, sql_loc) cur.execute( - "INSERT OR REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", - (max_current_ver, False) + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)" + ), + (max_current_ver, False,) ) _upgrade_existing_database( cur, current_version=max_current_ver, applied_delta_files=[], - upgraded=False + upgraded=False, + database_engine=database_engine, ) def _upgrade_existing_database(cur, current_version, applied_delta_files, - upgraded): + upgraded, database_engine): """Upgrades an existing database. Delta files can either be SQL stored in *.sql files, or python modules @@ -737,6 +311,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, if not upgraded: start_ver += 1 + logger.debug("applied_delta_files: %s", applied_delta_files) + for v in range(start_ver, SCHEMA_VERSION + 1): logger.debug("Upgrading schema to v%d", v) @@ -753,6 +329,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, directory_entries.sort() for file_name in directory_entries: relative_path = os.path.join(str(v), file_name) + logger.debug("Found file: %s", relative_path) if relative_path in applied_delta_files: continue @@ -774,9 +351,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, module.run_upgrade(cur) elif ext == ".sql": # A plain old .sql file, just read and execute it - delta_schema = read_schema(absolute_path) logger.debug("Applying schema %s", relative_path) - cur.executescript(delta_schema) + executescript(cur, absolute_path) else: # Not a valid delta file. logger.warn( @@ -788,24 +364,83 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files, # Mark as done. cur.execute( - "INSERT INTO applied_schema_deltas (version, file)" - " VALUES (?,?)", + database_engine.convert_param_style( + "INSERT INTO applied_schema_deltas (version, file)" + " VALUES (?,?)", + ), (v, relative_path) ) + cur.execute("DELETE FROM schema_version") cur.execute( - "INSERT OR REPLACE INTO schema_version (version, upgraded)" - " VALUES (?,?)", + database_engine.convert_param_style( + "INSERT INTO schema_version (version, upgraded)" + " VALUES (?,?)", + ), (v, True) ) -def _get_or_create_schema_state(txn): +def get_statements(f): + statement_buffer = "" + in_comment = False # If we're in a /* ... */ style comment + + for line in f: + line = line.strip() + + if in_comment: + # Check if this line contains an end to the comment + comments = line.split("*/", 1) + if len(comments) == 1: + continue + line = comments[1] + in_comment = False + + # Remove inline block comments + line = re.sub(r"/\*.*\*/", " ", line) + + # Does this line start a comment? + comments = line.split("/*", 1) + if len(comments) > 1: + line = comments[0] + in_comment = True + + # Deal with line comments + line = line.split("--", 1)[0] + line = line.split("//", 1)[0] + + # Find *all* semicolons. We need to treat first and last entry + # specially. + statements = line.split(";") + + # We must prepend statement_buffer to the first statement + first_statement = "%s %s" % ( + statement_buffer.strip(), + statements[0].strip() + ) + statements[0] = first_statement + + # Every entry, except the last, is a full statement + for statement in statements[:-1]: + yield statement.strip() + + # The last entry did *not* end in a semicolon, so we store it for the + # next semicolon we find + statement_buffer = statements[-1].strip() + + +def executescript(txn, schema_path): + with open(schema_path, 'r') as f: + for statement in get_statements(f): + txn.execute(statement) + + +def _get_or_create_schema_state(txn, database_engine): + # Bluntly try creating the schema_version tables. schema_path = os.path.join( dir_path, "schema", "schema_version.sql", ) - create_schema = read_schema(schema_path) - txn.executescript(create_schema) + executescript(txn, schema_path) txn.execute("SELECT version, upgraded FROM schema_version") row = txn.fetchone() @@ -814,10 +449,13 @@ def _get_or_create_schema_state(txn): if current_version: txn.execute( - "SELECT file FROM applied_schema_deltas WHERE version >= ?", + database_engine.convert_param_style( + "SELECT file FROM applied_schema_deltas WHERE version >= ?" + ), (current_version,) ) - return current_version, txn.fetchall(), upgraded + applied_deltas = [d for d, in txn.fetchall()] + return current_version, applied_deltas, upgraded return None @@ -849,7 +487,19 @@ def prepare_sqlite3_database(db_conn): if row and row[0]: db_conn.execute( - "INSERT OR REPLACE INTO schema_version (version, upgraded)" + "REPLACE INTO schema_version (version, upgraded)" " VALUES (?,?)", (row[0], False) ) + + +def are_all_users_on_domain(txn, database_engine, domain): + sql = database_engine.convert_param_style( + "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?" + ) + pat = "%:" + domain + txn.execute(sql, (pat,)) + num_not_matching = txn.fetchall()[0][0] + if num_not_matching == 0: + return True + return False |