diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 4b16f445d6..0cc14fb692 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -14,13 +14,12 @@
# limitations under the License.
from twisted.internet import defer
-
-from synapse.util.logutils import log_function
-from synapse.api.constants import EventTypes
-
-from .appservice import ApplicationServiceStore
+from .appservice import (
+ ApplicationServiceStore, ApplicationServiceTransactionStore
+)
+from ._base import Cache
from .directory import DirectoryStore
-from .feedback import FeedbackStore
+from .events import EventsStore
from .presence import PresenceStore
from .profile import ProfileStore
from .registration import RegistrationStore
@@ -39,11 +38,6 @@ from .state import StateStore
from .signatures import SignatureStore
from .filtering import FilteringStore
-from syutil.base64util import decode_base64
-from syutil.jsonutil import encode_canonical_json
-
-from synapse.crypto.event_signing import compute_event_reference_hash
-
import fnmatch
import imp
@@ -57,20 +51,18 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 14
+SCHEMA_VERSION = 17
dir_path = os.path.abspath(os.path.dirname(__file__))
-
-class _RollbackButIsFineException(Exception):
- """ This exception is used to rollback a transaction without implying
- something went wrong.
- """
- pass
+# Number of msec of granularity for storing the user IP 'last seen' time.
+# Smaller times give more inserts into the database, even for read-only
+# API hits.
+# 120 seconds == 2 minutes
+LAST_SEEN_GRANULARITY = 120 * 1000
class DataStore(RoomMemberStore, RoomStore,
- RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
+ RegistrationStore, StreamStore, ProfileStore,
PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore,
ApplicationServiceStore,
@@ -79,7 +71,9 @@ class DataStore(RoomMemberStore, RoomStore,
RejectionsStore,
FilteringStore,
PusherStore,
- PushRuleStore
+ PushRuleStore,
+ ApplicationServiceTransactionStore,
+ EventsStore,
):
def __init__(self, hs):
@@ -89,474 +83,53 @@ class DataStore(RoomMemberStore, RoomStore,
self.min_token_deferred = self._get_min_token()
self.min_token = None
- @defer.inlineCallbacks
- @log_function
- def persist_event(self, event, context, backfilled=False,
- is_new_state=True, current_state=None):
- stream_ordering = None
- if backfilled:
- if not self.min_token_deferred.called:
- yield self.min_token_deferred
- self.min_token -= 1
- stream_ordering = self.min_token
-
- try:
- yield self.runInteraction(
- "persist_event",
- self._persist_event_txn,
- event=event,
- context=context,
- backfilled=backfilled,
- stream_ordering=stream_ordering,
- is_new_state=is_new_state,
- current_state=current_state,
- )
- except _RollbackButIsFineException:
- pass
-
- @defer.inlineCallbacks
- def get_event(self, event_id, check_redacted=True,
- get_prev_content=False, allow_rejected=False,
- allow_none=False):
- """Get an event from the database by event_id.
-
- Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
- False throw an exception.
-
- Returns:
- Deferred : A FrozenEvent.
- """
- event = yield self.runInteraction(
- "get_event", self._get_event_txn,
- event_id,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- if not event and not allow_none:
- raise RuntimeError("Could not find event %s" % (event_id,))
-
- defer.returnValue(event)
-
- @log_function
- def _persist_event_txn(self, txn, event, context, backfilled,
- stream_ordering=None, is_new_state=True,
- current_state=None):
-
- # Remove the any existing cache entries for the event_id
- self._get_event_cache.pop(event.event_id)
-
- # We purposefully do this first since if we include a `current_state`
- # key, we *want* to update the `current_state_events` table
- if current_state:
- txn.execute(
- "DELETE FROM current_state_events WHERE room_id = ?",
- (event.room_id,)
- )
-
- for s in current_state:
- self._simple_insert_txn(
- txn,
- "current_state_events",
- {
- "event_id": s.event_id,
- "room_id": s.room_id,
- "type": s.type,
- "state_key": s.state_key,
- },
- or_replace=True,
- )
-
- if event.is_state() and is_new_state:
- if not backfilled and not context.rejected:
- self._simple_insert_txn(
- txn,
- table="state_forward_extremities",
- values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- },
- or_replace=True,
- )
-
- for prev_state_id, _ in event.prev_state:
- self._simple_delete_txn(
- txn,
- table="state_forward_extremities",
- keyvalues={
- "event_id": prev_state_id,
- }
- )
-
- outlier = event.internal_metadata.is_outlier()
-
- if not outlier:
- self._store_state_groups_txn(txn, event, context)
-
- self._update_min_depth_for_room_txn(
- txn,
- event.room_id,
- event.depth
- )
-
- self._handle_prev_events(
- txn,
- outlier=outlier,
- event_id=event.event_id,
- prev_events=event.prev_events,
- room_id=event.room_id,
- )
-
- have_persisted = self._simple_select_one_onecol_txn(
- txn,
- table="event_json",
- keyvalues={"event_id": event.event_id},
- retcol="event_id",
- allow_none=True,
- )
-
- metadata_json = encode_canonical_json(
- event.internal_metadata.get_dict()
- )
-
- # If we have already persisted this event, we don't need to do any
- # more processing.
- # The processing above must be done on every call to persist event,
- # since they might not have happened on previous calls. For example,
- # if we are persisting an event that we had persisted as an outlier,
- # but is no longer one.
- if have_persisted:
- if not outlier:
- sql = (
- "UPDATE event_json SET internal_metadata = ?"
- " WHERE event_id = ?"
- )
- txn.execute(
- sql,
- (metadata_json.decode("UTF-8"), event.event_id,)
- )
-
- sql = (
- "UPDATE events SET outlier = 0"
- " WHERE event_id = ?"
- )
- txn.execute(
- sql,
- (event.event_id,)
- )
- return
-
- if event.type == EventTypes.Member:
- self._store_room_member_txn(txn, event)
- elif event.type == EventTypes.Feedback:
- self._store_feedback_txn(txn, event)
- elif event.type == EventTypes.Name:
- self._store_room_name_txn(txn, event)
- elif event.type == EventTypes.Topic:
- self._store_room_topic_txn(txn, event)
- elif event.type == EventTypes.Redaction:
- self._store_redaction(txn, event)
-
- event_dict = {
- k: v
- for k, v in event.get_dict().items()
- if k not in [
- "redacted",
- "redacted_because",
- ]
- }
-
- self._simple_insert_txn(
- txn,
- table="event_json",
- values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "internal_metadata": metadata_json.decode("UTF-8"),
- "json": encode_canonical_json(event_dict).decode("UTF-8"),
- },
- or_replace=True,
- )
-
- content = encode_canonical_json(
- event.content
- ).decode("UTF-8")
-
- vals = {
- "topological_ordering": event.depth,
- "event_id": event.event_id,
- "type": event.type,
- "room_id": event.room_id,
- "content": content,
- "processed": True,
- "outlier": outlier,
- "depth": event.depth,
- }
-
- if stream_ordering is not None:
- vals["stream_ordering"] = stream_ordering
-
- unrec = {
- k: v
- for k, v in event.get_dict().items()
- if k not in vals.keys() and k not in [
- "redacted",
- "redacted_because",
- "signatures",
- "hashes",
- "prev_events",
- ]
- }
-
- vals["unrecognized_keys"] = encode_canonical_json(
- unrec
- ).decode("UTF-8")
-
- try:
- self._simple_insert_txn(
- txn,
- "events",
- vals,
- or_replace=(not outlier),
- or_ignore=bool(outlier),
- )
- except:
- logger.warn(
- "Failed to persist, probably duplicate: %s",
- event.event_id,
- exc_info=True,
- )
- raise _RollbackButIsFineException("_persist_event")
-
- if context.rejected:
- self._store_rejections_txn(txn, event.event_id, context.rejected)
-
- if event.is_state():
- vals = {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- }
-
- # TODO: How does this work with backfilling?
- if hasattr(event, "replaces_state"):
- vals["prev_state"] = event.replaces_state
-
- self._simple_insert_txn(
- txn,
- "state_events",
- vals,
- or_replace=True,
- )
-
- if is_new_state and not context.rejected:
- self._simple_insert_txn(
- txn,
- "current_state_events",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "type": event.type,
- "state_key": event.state_key,
- },
- or_replace=True,
- )
-
- for e_id, h in event.prev_state:
- self._simple_insert_txn(
- txn,
- table="event_edges",
- values={
- "event_id": event.event_id,
- "prev_event_id": e_id,
- "room_id": event.room_id,
- "is_state": 1,
- },
- or_ignore=True,
- )
-
- for hash_alg, hash_base64 in event.hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_event_content_hash_txn(
- txn, event.event_id, hash_alg, hash_bytes,
- )
-
- for prev_event_id, prev_hashes in event.prev_events:
- for alg, hash_base64 in prev_hashes.items():
- hash_bytes = decode_base64(hash_base64)
- self._store_prev_event_hash_txn(
- txn, event.event_id, prev_event_id, alg, hash_bytes
- )
-
- for auth_id, _ in event.auth_events:
- self._simple_insert_txn(
- txn,
- table="event_auth",
- values={
- "event_id": event.event_id,
- "room_id": event.room_id,
- "auth_id": auth_id,
- },
- or_ignore=True,
- )
-
- (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
- self._store_event_reference_hash_txn(
- txn, event.event_id, ref_alg, ref_hash_bytes
- )
-
- def _store_redaction(self, txn, event):
- # invalidate the cache for the redacted event
- self._get_event_cache.pop(event.redacts)
- txn.execute(
- "INSERT OR IGNORE INTO redactions "
- "(event_id, redacts) VALUES (?,?)",
- (event.event_id, event.redacts)
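+ # keylen=4 matches the (user_id, access_token, device_id, ip) tuple
+ # used as the cache key in insert_client_ip below.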
+ self.client_ip_last_seen = Cache(
+ name="client_ip_last_seen",
+ keylen=4,
)
@defer.inlineCallbacks
- def get_current_state(self, room_id, event_type=None, state_key=""):
- del_sql = (
- "SELECT event_id FROM redactions WHERE redacts = e.event_id "
- "LIMIT 1"
- )
-
- sql = (
- "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
- "INNER JOIN current_state_events as c ON e.event_id = c.event_id "
- "INNER JOIN state_events as s ON e.event_id = s.event_id "
- "WHERE c.room_id = ? "
- ) % {
- "redacted": del_sql,
- }
-
- if event_type and state_key is not None:
- sql += " AND s.type = ? AND s.state_key = ? "
- args = (room_id, event_type, state_key)
- elif event_type:
- sql += " AND s.type = ?"
- args = (room_id, event_type)
- else:
- args = (room_id, )
-
- results = yield self._execute_and_decode("get_current_state", sql, *args)
-
- events = yield self._parse_events(results)
- defer.returnValue(events)
-
- @defer.inlineCallbacks
- def get_room_name_and_aliases(self, room_id):
- del_sql = (
- "SELECT event_id FROM redactions WHERE redacts = e.event_id "
- "LIMIT 1"
- )
-
- sql = (
- "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
- "INNER JOIN current_state_events as c ON e.event_id = c.event_id "
- "INNER JOIN state_events as s ON e.event_id = s.event_id "
- "WHERE c.room_id = ? "
- ) % {
- "redacted": del_sql,
- }
-
- sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
- sql += " OR s.type = 'm.room.aliases')"
- args = (room_id,)
-
- results = yield self._execute_and_decode("get_current_state", sql, *args)
-
- events = yield self._parse_events(results)
-
- name = None
- aliases = []
-
- for e in events:
- if e.type == 'm.room.name':
- if 'name' in e.content:
- name = e.content['name']
- elif e.type == 'm.room.aliases':
- if 'aliases' in e.content:
- aliases.extend(e.content['aliases'])
-
- defer.returnValue((name, aliases))
-
- @defer.inlineCallbacks
- def _get_min_token(self):
- row = yield self._execute(
- "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
- )
+ def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
+ now = int(self._clock.time_msec())
+ key = (user.to_string(), access_token, device_id, ip)
- self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
- self.min_token = min(self.min_token, -1)
+ try:
+ last_seen = self.client_ip_last_seen.get(*key)
+ except KeyError:
+ last_seen = None
- logger.debug("min_token is: %s", self.min_token)
+ # Rate-limit inserts: skip the write if we saw this client recently
+ if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
+ defer.returnValue(None)
- defer.returnValue(self.min_token)
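+ # Update the cache first, so concurrent requests inside the granularity
+ # window skip the database write entirely.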
+ self.client_ip_last_seen.prefill(*key + (now,))
- def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
- return self._simple_insert(
+ # It's safe not to lock here: a) the table has no unique constraint, and
+ # b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely.
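+ # keyvalues identify the existing row (if any); values are written on
+ # both insert and update.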
+ yield self._simple_upsert(
"user_ips",
- {
- "user": user.to_string(),
+ keyvalues={
+ "user_id": user.to_string(),
"access_token": access_token,
- "device_id": device_id,
"ip": ip,
"user_agent": user_agent,
- "last_seen": int(self._clock.time_msec()),
- }
+ },
+ values={
+ "device_id": device_id,
+ "last_seen": now,
+ },
+ desc="insert_client_ip",
+ lock=False,
)
def get_user_ip_and_agents(self, user):
return self._simple_select_list(
table="user_ips",
- keyvalues={"user": user.to_string()},
+ keyvalues={"user_id": user.to_string()},
retcols=[
"device_id", "access_token", "ip", "user_agent", "last_seen"
],
- )
-
- def have_events(self, event_ids):
- """Given a list of event ids, check if we have already processed them.
-
- Returns:
- dict: Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps to
- None.
- """
- if not event_ids:
- return defer.succeed({})
-
- def f(txn):
- sql = (
- "SELECT e.event_id, reason FROM events as e "
- "LEFT JOIN rejections as r ON e.event_id = r.event_id "
- "WHERE e.event_id = ?"
- )
-
- res = {}
- for event_id in event_ids:
- txn.execute(sql, (event_id,))
- row = txn.fetchone()
- if row:
- _, rejected = row
- res[event_id] = rejected
-
- return res
-
- return self.runInteraction(
- "have_events", f,
+ desc="get_user_ip_and_agents",
)
@@ -580,21 +153,23 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass
-def prepare_database(db_conn):
+def prepare_database(db_conn, database_engine):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
"""
try:
cur = db_conn.cursor()
- version_info = _get_or_create_schema_state(cur)
+ version_info = _get_or_create_schema_state(cur, database_engine)
if version_info:
user_version, delta_files, upgraded = version_info
- _upgrade_existing_database(cur, user_version, delta_files, upgraded)
+ _upgrade_existing_database(
+ cur, user_version, delta_files, upgraded, database_engine
+ )
else:
- _setup_new_database(cur)
+ _setup_new_database(cur, database_engine)
- cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
+ # cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
cur.close()
db_conn.commit()
@@ -603,7 +178,7 @@ def prepare_database(db_conn):
raise
-def _setup_new_database(cur):
+def _setup_new_database(cur, database_engine):
"""Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas.
@@ -657,31 +232,30 @@ def _setup_new_database(cur):
directory_entries = os.listdir(sql_dir)
- sql_script = "BEGIN TRANSACTION;\n"
for filename in fnmatch.filter(directory_entries, "*.sql"):
sql_loc = os.path.join(sql_dir, filename)
logger.debug("Applying schema %s", sql_loc)
- sql_script += read_schema(sql_loc)
- sql_script += "\n"
- sql_script += "COMMIT TRANSACTION;"
- cur.executescript(sql_script)
+ executescript(cur, sql_loc)
cur.execute(
- "INSERT OR REPLACE INTO schema_version (version, upgraded)"
- " VALUES (?,?)",
- (max_current_ver, False)
+ database_engine.convert_param_style(
+ "INSERT INTO schema_version (version, upgraded)"
+ " VALUES (?,?)"
+ ),
+ (max_current_ver, False,)
)
_upgrade_existing_database(
cur,
current_version=max_current_ver,
applied_delta_files=[],
- upgraded=False
+ upgraded=False,
+ database_engine=database_engine,
)
def _upgrade_existing_database(cur, current_version, applied_delta_files,
- upgraded):
+ upgraded, database_engine):
"""Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules
@@ -737,6 +311,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
if not upgraded:
start_ver += 1
+ logger.debug("applied_delta_files: %s", applied_delta_files)
+
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.debug("Upgrading schema to v%d", v)
@@ -753,6 +329,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
directory_entries.sort()
for file_name in directory_entries:
relative_path = os.path.join(str(v), file_name)
+ logger.debug("Found file: %s", relative_path)
if relative_path in applied_delta_files:
continue
@@ -774,9 +351,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module.run_upgrade(cur)
elif ext == ".sql":
# A plain old .sql file, just read and execute it
- delta_schema = read_schema(absolute_path)
logger.debug("Applying schema %s", relative_path)
- cur.executescript(delta_schema)
+ executescript(cur, absolute_path)
else:
# Not a valid delta file.
logger.warn(
@@ -788,24 +364,83 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
# Mark as done.
cur.execute(
- "INSERT INTO applied_schema_deltas (version, file)"
- " VALUES (?,?)",
+ database_engine.convert_param_style(
+ "INSERT INTO applied_schema_deltas (version, file)"
+ " VALUES (?,?)",
+ ),
(v, relative_path)
)
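+ # DELETE then INSERT replaces the sqlite-specific "INSERT OR REPLACE"
+ # so the same SQL runs on any engine.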
+ cur.execute("DELETE FROM schema_version")
cur.execute(
- "INSERT OR REPLACE INTO schema_version (version, upgraded)"
- " VALUES (?,?)",
+ database_engine.convert_param_style(
+ "INSERT INTO schema_version (version, upgraded)"
+ " VALUES (?,?)",
+ ),
(v, True)
)
-def _get_or_create_schema_state(txn):
+def get_statements(f):
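+ # Yields complete semicolon-terminated SQL statements from an iterable of
+ # lines, stripping /* ... */ block comments and -- / // line comments.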
+ statement_buffer = ""
+ in_comment = False # If we're in a /* ... */ style comment
+
+ for line in f:
+ line = line.strip()
+
+ if in_comment:
+ # Check if this line contains an end to the comment
+ comments = line.split("*/", 1)
+ if len(comments) == 1:
+ continue
+ line = comments[1]
+ in_comment = False
+
+ # Remove inline block comments
+ line = re.sub(r"/\*.*\*/", " ", line)
+
+ # Does this line start a comment?
+ comments = line.split("/*", 1)
+ if len(comments) > 1:
+ line = comments[0]
+ in_comment = True
+
+ # Deal with line comments
+ line = line.split("--", 1)[0]
+ line = line.split("//", 1)[0]
+
+ # Find *all* semicolons. We need to treat first and last entry
+ # specially.
+ statements = line.split(";")
+
+ # We must prepend statement_buffer to the first statement
+ first_statement = "%s %s" % (
+ statement_buffer.strip(),
+ statements[0].strip()
+ )
+ statements[0] = first_statement
+
+ # Every entry, except the last, is a full statement
+ for statement in statements[:-1]:
+ yield statement.strip()
+
+ # The last entry did *not* end in a semicolon, so we store it for the
+ # next semicolon we find
+ statement_buffer = statements[-1].strip()
+
+
+def executescript(txn, schema_path):
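+ # Portable replacement for sqlite3's Cursor.executescript(): standard
+ # DB-API cursors execute a single statement per call, so the script is
+ # split into individual statements first.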
+ with open(schema_path, 'r') as f:
+ for statement in get_statements(f):
+ txn.execute(statement)
+
+
+def _get_or_create_schema_state(txn, database_engine):
+ # Bluntly try creating the schema_version tables.
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
- create_schema = read_schema(schema_path)
- txn.executescript(create_schema)
+ executescript(txn, schema_path)
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
@@ -814,10 +449,13 @@ def _get_or_create_schema_state(txn):
if current_version:
txn.execute(
- "SELECT file FROM applied_schema_deltas WHERE version >= ?",
+ database_engine.convert_param_style(
+ "SELECT file FROM applied_schema_deltas WHERE version >= ?"
+ ),
(current_version,)
)
- return current_version, txn.fetchall(), upgraded
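+ # fetchall() returns rows as 1-tuples; "for d," unpacks each into the
+ # bare delta file name.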
+ applied_deltas = [d for d, in txn.fetchall()]
+ return current_version, applied_deltas, upgraded
return None
@@ -849,7 +487,19 @@ def prepare_sqlite3_database(db_conn):
if row and row[0]:
db_conn.execute(
- "INSERT OR REPLACE INTO schema_version (version, upgraded)"
+ "REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(row[0], False)
)
+
+
+def are_all_users_on_domain(txn, database_engine, domain):
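+ # True iff every row in the users table has a name ending in
+ # ":<domain>", e.g. "@alice:example.com".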
+ sql = database_engine.convert_param_style(
+ "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+ )
+ pat = "%:" + domain
+ txn.execute(sql, (pat,))
+ num_not_matching = txn.fetchall()[0][0]
+ return num_not_matching == 0
|