diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 4e9291fdff..d8f351a675 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -16,14 +16,7 @@
from twisted.internet import defer
from synapse.api.events.room import (
- RoomMemberEvent, RoomTopicEvent, FeedbackEvent,
-# RoomConfigEvent,
- RoomNameEvent,
- RoomJoinRulesEvent,
- RoomPowerLevelsEvent,
- RoomAddStateLevelEvent,
- RoomSendEventLevelEvent,
- RoomOpsPowerLevelsEvent,
+ RoomMemberEvent, RoomTopicEvent, FeedbackEvent, RoomNameEvent,
RoomRedactionEvent,
)
@@ -37,9 +30,17 @@ from .registration import RegistrationStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .stream import StreamStore
-from .pdu import StatePduStore, PduStore, PdusTable
from .transactions import TransactionStore
from .keys import KeyStore
+from .event_federation import EventFederationStore
+
+from .state import StateStore
+from .signatures import SignatureStore
+
+from syutil.base64util import decode_base64
+
+from synapse.crypto.event_signing import compute_event_reference_hash
+
import json
import logging
@@ -51,7 +52,6 @@ logger = logging.getLogger(__name__)
SCHEMAS = [
"transactions",
- "pdu",
"users",
"profiles",
"presence",
@@ -59,6 +59,9 @@ SCHEMAS = [
"room_aliases",
"keys",
"redactions",
+ "state",
+ "event_edges",
+ "event_signatures",
]
@@ -73,10 +76,12 @@ class _RollbackButIsFineException(Exception):
"""
pass
+
class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
- PresenceStore, PduStore, StatePduStore, TransactionStore,
- DirectoryStore, KeyStore):
+ PresenceStore, TransactionStore,
+ DirectoryStore, KeyStore, StateStore, SignatureStore,
+ EventFederationStore, ):
def __init__(self, hs):
super(DataStore, self).__init__(hs)
@@ -88,8 +93,7 @@ class DataStore(RoomMemberStore, RoomStore,
@defer.inlineCallbacks
@log_function
- def persist_event(self, event=None, backfilled=False, pdu=None,
- is_new_state=True):
+ def persist_event(self, event, backfilled=False, is_new_state=True):
stream_ordering = None
if backfilled:
if not self.min_token_deferred.called:
@@ -99,8 +103,8 @@ class DataStore(RoomMemberStore, RoomStore,
try:
yield self.runInteraction(
- self._persist_pdu_event_txn,
- pdu=pdu,
+ "persist_event",
+ self._persist_event_txn,
event=event,
backfilled=backfilled,
stream_ordering=stream_ordering,
@@ -119,7 +123,8 @@ class DataStore(RoomMemberStore, RoomStore,
"type",
"room_id",
"content",
- "unrecognized_keys"
+ "unrecognized_keys",
+ "depth",
],
allow_none=allow_none,
)
@@ -130,42 +135,6 @@ class DataStore(RoomMemberStore, RoomStore,
event = self._parse_event_from_row(events_dict)
defer.returnValue(event)
- def _persist_pdu_event_txn(self, txn, pdu=None, event=None,
- backfilled=False, stream_ordering=None,
- is_new_state=True):
- if pdu is not None:
- self._persist_event_pdu_txn(txn, pdu)
- if event is not None:
- return self._persist_event_txn(
- txn, event, backfilled, stream_ordering,
- is_new_state=is_new_state,
- )
-
- def _persist_event_pdu_txn(self, txn, pdu):
- cols = dict(pdu.__dict__)
- unrec_keys = dict(pdu.unrecognized_keys)
- del cols["content"]
- del cols["prev_pdus"]
- cols["content_json"] = json.dumps(pdu.content)
-
- unrec_keys.update({
- k: v for k, v in cols.items()
- if k not in PdusTable.fields
- })
-
- cols["unrecognized_keys"] = json.dumps(unrec_keys)
-
- cols["ts"] = cols.pop("origin_server_ts")
-
- logger.debug("Persisting: %s", repr(cols))
-
- if pdu.is_state:
- self._persist_state_txn(txn, pdu.prev_pdus, cols)
- else:
- self._persist_pdu_txn(txn, pdu.prev_pdus, cols)
-
- self._update_min_depth_for_context_txn(txn, pdu.context, pdu.depth)
-
@log_function
def _persist_event_txn(self, txn, event, backfilled, stream_ordering=None,
is_new_state=True):
@@ -177,19 +146,13 @@ class DataStore(RoomMemberStore, RoomStore,
self._store_room_name_txn(txn, event)
elif event.type == RoomTopicEvent.TYPE:
self._store_room_topic_txn(txn, event)
- elif event.type == RoomJoinRulesEvent.TYPE:
- self._store_join_rule(txn, event)
- elif event.type == RoomPowerLevelsEvent.TYPE:
- self._store_power_levels(txn, event)
- elif event.type == RoomAddStateLevelEvent.TYPE:
- self._store_add_state_level(txn, event)
- elif event.type == RoomSendEventLevelEvent.TYPE:
- self._store_send_event_level(txn, event)
- elif event.type == RoomOpsPowerLevelsEvent.TYPE:
- self._store_ops_level(txn, event)
elif event.type == RoomRedactionEvent.TYPE:
self._store_redaction(txn, event)
+ outlier = False
+ if hasattr(event, "outlier"):
+ outlier = event.outlier
+
vals = {
"topological_ordering": event.depth,
"event_id": event.event_id,
@@ -197,25 +160,34 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id,
"content": json.dumps(event.content),
"processed": True,
+ "outlier": outlier,
+ "depth": event.depth,
}
if stream_ordering is not None:
vals["stream_ordering"] = stream_ordering
- if hasattr(event, "outlier"):
- vals["outlier"] = event.outlier
- else:
- vals["outlier"] = False
-
unrec = {
k: v
for k, v in event.get_full_dict().items()
- if k not in vals.keys() and k not in ["redacted", "redacted_because"]
+ if k not in vals.keys() and k not in [
+ "redacted",
+ "redacted_because",
+ "signatures",
+ "hashes",
+ "prev_events",
+ ]
}
vals["unrecognized_keys"] = json.dumps(unrec)
try:
- self._simple_insert_txn(txn, "events", vals)
+ self._simple_insert_txn(
+ txn,
+ "events",
+ vals,
+ or_replace=(not outlier),
+ or_ignore=bool(outlier),
+ )
except:
logger.warn(
"Failed to persist, probably duplicate: %s",
@@ -224,6 +196,16 @@ class DataStore(RoomMemberStore, RoomStore,
)
raise _RollbackButIsFineException("_persist_event")
+ self._handle_prev_events(
+ txn,
+ outlier=outlier,
+ event_id=event.event_id,
+ prev_events=event.prev_events,
+ room_id=event.room_id,
+ )
+
+ self._store_state_groups_txn(txn, event)
+
is_state = hasattr(event, "state_key") and event.state_key is not None
if is_new_state and is_state:
vals = {
@@ -233,10 +215,15 @@ class DataStore(RoomMemberStore, RoomStore,
"state_key": event.state_key,
}
- if hasattr(event, "prev_state"):
- vals["prev_state"] = event.prev_state
+ if hasattr(event, "replaces_state"):
+ vals["prev_state"] = event.replaces_state
- self._simple_insert_txn(txn, "state_events", vals)
+ self._simple_insert_txn(
+ txn,
+ "state_events",
+ vals,
+ or_replace=True,
+ )
self._simple_insert_txn(
txn,
@@ -246,9 +233,87 @@ class DataStore(RoomMemberStore, RoomStore,
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
- }
+ },
+ or_replace=True,
+ )
+
+ for e_id, h in event.prev_state:
+ self._simple_insert_txn(
+ txn,
+ table="event_edges",
+ values={
+ "event_id": event.event_id,
+ "prev_event_id": e_id,
+ "room_id": event.room_id,
+ "is_state": 1,
+ },
+ or_ignore=True,
+ )
+
+ if not backfilled:
+ self._simple_insert_txn(
+ txn,
+ table="state_forward_extremities",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "state_key": event.state_key,
+ },
+ or_replace=True,
+ )
+
+ for prev_state_id, _ in event.prev_state:
+ self._simple_delete_txn(
+ txn,
+ table="state_forward_extremities",
+ keyvalues={
+ "event_id": prev_state_id,
+ }
+ )
+
+ for hash_alg, hash_base64 in event.hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_event_content_hash_txn(
+ txn, event.event_id, hash_alg, hash_bytes,
)
+ if hasattr(event, "signatures"):
+ logger.debug("sigs: %s", event.signatures)
+ for name, sigs in event.signatures.items():
+ for key_id, signature_base64 in sigs.items():
+ signature_bytes = decode_base64(signature_base64)
+ self._store_event_signature_txn(
+ txn, event.event_id, name, key_id,
+ signature_bytes,
+ )
+
+ for prev_event_id, prev_hashes in event.prev_events:
+ for alg, hash_base64 in prev_hashes.items():
+ hash_bytes = decode_base64(hash_base64)
+ self._store_prev_event_hash_txn(
+ txn, event.event_id, prev_event_id, alg, hash_bytes
+ )
+
+ for auth_id, _ in event.auth_events:
+ self._simple_insert_txn(
+ txn,
+ table="event_auth",
+ values={
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "auth_id": auth_id,
+ },
+ or_ignore=True,
+ )
+
+ (ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
+ self._store_event_reference_hash_txn(
+ txn, event.event_id, ref_alg, ref_hash_bytes
+ )
+
+ self._update_min_depth_for_room_txn(txn, event.room_id, event.depth)
+
def _store_redaction(self, txn, event):
txn.execute(
"INSERT OR IGNORE INTO redactions "
@@ -319,7 +384,7 @@ class DataStore(RoomMemberStore, RoomStore,
],
)
- def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
+ def snapshot_room(self, event):
"""Snapshot the room for an update by a user
Args:
room_id (synapse.types.RoomId): The room to snapshot.
@@ -330,29 +395,33 @@ class DataStore(RoomMemberStore, RoomStore,
synapse.storage.Snapshot: A snapshot of the state of the room.
"""
def _snapshot(txn):
- membership_state = self._get_room_member(txn, user_id, room_id)
- prev_pdus = self._get_latest_pdus_in_context(
- txn, room_id
+ prev_events = self._get_latest_events_in_room(
+ txn,
+ event.room_id
)
- if state_type is not None and state_key is not None:
- prev_state_pdu = self._get_current_state_pdu(
- txn, room_id, state_type, state_key
+
+ prev_state = None
+ state_key = None
+ if hasattr(event, "state_key"):
+ state_key = event.state_key
+ prev_state = self._get_latest_state_in_room(
+ txn,
+ event.room_id,
+ type=event.type,
+ state_key=state_key,
)
- else:
- prev_state_pdu = None
return Snapshot(
store=self,
- room_id=room_id,
- user_id=user_id,
- prev_pdus=prev_pdus,
- membership_state=membership_state,
- state_type=state_type,
+ room_id=event.room_id,
+ user_id=event.user_id,
+ prev_events=prev_events,
+ prev_state=prev_state,
+ state_type=event.type,
state_key=state_key,
- prev_state_pdu=prev_state_pdu,
)
- return self.runInteraction(_snapshot)
+ return self.runInteraction("snapshot_room", _snapshot)
class Snapshot(object):
@@ -361,7 +430,7 @@ class Snapshot(object):
store (DataStore): The datastore.
room_id (RoomId): The room of the snapshot.
user_id (UserId): The user this snapshot is for.
- prev_pdus (list): The list of PDU ids this snapshot is after.
+ prev_events (list): The list of event ids this snapshot is after.
membership_state (RoomMemberEvent): The current state of the user in
the room.
state_type (str, optional): State type captured by the snapshot
@@ -370,32 +439,30 @@ class Snapshot(object):
the previous value of the state type and key in the room.
"""
- def __init__(self, store, room_id, user_id, prev_pdus,
- membership_state, state_type=None, state_key=None,
- prev_state_pdu=None):
+ def __init__(self, store, room_id, user_id, prev_events,
+ prev_state, state_type=None, state_key=None):
self.store = store
self.room_id = room_id
self.user_id = user_id
- self.prev_pdus = prev_pdus
- self.membership_state = membership_state
+ self.prev_events = prev_events
+ self.prev_state = prev_state
self.state_type = state_type
self.state_key = state_key
- self.prev_state_pdu = prev_state_pdu
def fill_out_prev_events(self, event):
- if hasattr(event, "prev_events"):
- return
-
- es = [
- "%s@%s" % (p_id, origin) for p_id, origin, _ in self.prev_pdus
- ]
-
- event.prev_events = [e for e in es if e != event.event_id]
+ if not hasattr(event, "prev_events"):
+ event.prev_events = [
+ (event_id, hashes)
+ for event_id, hashes, _ in self.prev_events
+ ]
+
+ if self.prev_events:
+ event.depth = max([int(v) for _, _, v in self.prev_events]) + 1
+ else:
+ event.depth = 0
- if self.prev_pdus:
- event.depth = max([int(v) for _, _, v in self.prev_pdus]) + 1
- else:
- event.depth = 0
+ if not hasattr(event, "prev_state") and self.prev_state is not None:
+ event.prev_state = self.prev_state
def schema_path(schema):
@@ -436,11 +503,13 @@ def prepare_database(db_conn):
user_version = row[0]
if user_version > SCHEMA_VERSION:
- raise ValueError("Cannot use this database as it is too " +
+ raise ValueError(
+ "Cannot use this database as it is too " +
"new for the server to understand"
)
elif user_version < SCHEMA_VERSION:
- logging.info("Upgrading database from version %d",
+ logging.info(
+ "Upgrading database from version %d",
user_version
)
@@ -452,13 +521,13 @@ def prepare_database(db_conn):
db_conn.commit()
else:
- sql_script = "BEGIN TRANSACTION;"
+ sql_script = "BEGIN TRANSACTION;\n"
for sql_loc in SCHEMAS:
sql_script += read_schema(sql_loc)
+ sql_script += "\n"
sql_script += "COMMIT TRANSACTION;"
c.executescript(sql_script)
db_conn.commit()
c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
c.close()
-
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2faa63904e..30e6eac8db 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,60 +14,72 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-
from synapse.api.errors import StoreError
from synapse.api.events.utils import prune_event
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
+from syutil.base64util import encode_base64
+
+from twisted.internet import defer
import collections
import copy
import json
+import sys
+import time
logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging to the .execute() method."""
- __slots__ = ["txn"]
+ __slots__ = ["txn", "name"]
- def __init__(self, txn):
+ def __init__(self, txn, name):
object.__setattr__(self, "txn", txn)
+ object.__setattr__(self, "name", name)
- def __getattribute__(self, name):
- if name == "execute":
- return object.__getattribute__(self, "execute")
-
- return getattr(object.__getattribute__(self, "txn"), name)
+ def __getattr__(self, name):
+ return getattr(self.txn, name)
def __setattr__(self, name, value):
- setattr(object.__getattribute__(self, "txn"), name, value)
+ setattr(self.txn, name, value)
def execute(self, sql, *args, **kwargs):
# TODO(paul): Maybe use 'info' and 'debug' for values?
- sql_logger.debug("[SQL] %s", sql)
+ sql_logger.debug("[SQL] {%s} %s", self.name, sql)
try:
if args and args[0]:
values = args[0]
- sql_logger.debug("[SQL values] " +
- ", ".join(("<%s>",) * len(values)), *values)
+ sql_logger.debug(
+ "[SQL values] {%s} " + ", ".join(("<%s>",) * len(values)),
+ self.name,
+ *values
+ )
except:
# Don't let logging failures stop SQL from working
pass
- # TODO(paul): Here would be an excellent place to put some timing
- # measurements, and log (warning?) slow queries.
- return object.__getattribute__(self, "txn").execute(
- sql, *args, **kwargs
- )
+ start = time.clock() * 1000
+ try:
+ return self.txn.execute(
+ sql, *args, **kwargs
+ )
+ except:
+ logger.exception("[SQL FAIL] {%s}", self.name)
+ raise
+ finally:
+ end = time.clock() * 1000
+ sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
class SQLBaseStore(object):
+ _TXN_ID = 0
def __init__(self, hs):
self.hs = hs
@@ -76,13 +88,34 @@ class SQLBaseStore(object):
self._clock = hs.get_clock()
@defer.inlineCallbacks
- def runInteraction(self, func, *args, **kwargs):
+ def runInteraction(self, desc, func, *args, **kwargs):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
def inner_func(txn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
current_context.copy_to(context)
- return func(LoggingTransaction(txn), *args, **kwargs)
+ start = time.clock() * 1000
+ txn_id = SQLBaseStore._TXN_ID
+
+ # We don't really need these to be unique, so lets stop it from
+ # growing really large.
+ self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+
+ name = "%s-%x" % (desc, txn_id, )
+
+ transaction_logger.debug("[TXN START] {%s}", name)
+ try:
+ return func(LoggingTransaction(txn, name), *args, **kwargs)
+ except:
+ logger.exception("[TXN FAIL] {%s}", name)
+ raise
+ finally:
+ end = time.clock() * 1000
+ transaction_logger.debug(
+ "[TXN END] {%s} %f",
+ name, end - start
+ )
+
with PreserveLoggingContext():
result = yield self._db_pool.runInteraction(
inner_func, *args, **kwargs
@@ -121,7 +154,7 @@ class SQLBaseStore(object):
else:
return cursor.fetchall()
- return self.runInteraction(interaction)
+ return self.runInteraction("_execute", interaction)
def _execute_and_decode(self, query, *args):
return self._execute(self.cursor_to_dict, query, *args)
@@ -138,6 +171,7 @@ class SQLBaseStore(object):
or_replace : bool; if True performs an INSERT OR REPLACE
"""
return self.runInteraction(
+ "_simple_insert",
self._simple_insert_txn, table, values, or_replace=or_replace,
or_ignore=or_ignore,
)
@@ -178,7 +212,6 @@ class SQLBaseStore(object):
table, keyvalues, retcols=retcols, allow_none=allow_none
)
- @defer.inlineCallbacks
def _simple_select_one_onecol(self, table, keyvalues, retcol,
allow_none=False):
"""Executes a SELECT query on the named table, which is expected to
@@ -189,19 +222,40 @@ class SQLBaseStore(object):
keyvalues : dict of column names and values to select the row with
retcol : string giving the name of the column to return
"""
- ret = yield self._simple_select_one(
+ return self.runInteraction(
+ "_simple_select_one_onecol",
+ self._simple_select_one_onecol_txn,
+ table, keyvalues, retcol, allow_none=allow_none,
+ )
+
+ def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+ allow_none=False):
+ ret = self._simple_select_onecol_txn(
+ txn,
table=table,
keyvalues=keyvalues,
- retcols=[retcol],
- allow_none=allow_none
+ retcol=retcol,
)
if ret:
- defer.returnValue(ret[retcol])
+ return ret[0]
else:
- defer.returnValue(None)
+ if allow_none:
+ return None
+ else:
+ raise StoreError(404, "No row found")
+
+ def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+ sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
+ "retcol": retcol,
+ "table": table,
+ "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
+ }
+
+ txn.execute(sql, keyvalues.values())
+
+ return [r[0] for r in txn.fetchall()]
- @defer.inlineCallbacks
def _simple_select_onecol(self, table, keyvalues, retcol):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -214,25 +268,33 @@ class SQLBaseStore(object):
Returns:
Deferred: Results in a list
"""
- sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
- "retcol": retcol,
- "table": table,
- "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
- }
-
- def func(txn):
- txn.execute(sql, keyvalues.values())
- return txn.fetchall()
+ return self.runInteraction(
+ "_simple_select_onecol",
+ self._simple_select_onecol_txn,
+ table, keyvalues, retcol
+ )
- res = yield self.runInteraction(func)
+ def _simple_select_list(self, table, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
- defer.returnValue([r[0] for r in res])
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ return self.runInteraction(
+ "_simple_select_list",
+ self._simple_select_list_txn,
+ table, keyvalues, retcols
+ )
- def _simple_select_list(self, table, keyvalues, retcols):
+ def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
Args:
+ txn : Transaction object
table : string giving the table name
keyvalues : dict of column names and values to select the rows with
retcols : list of strings giving the names of the columns to return
@@ -240,14 +302,11 @@ class SQLBaseStore(object):
sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- def func(txn):
- txn.execute(sql, keyvalues.values())
- return self.cursor_to_dict(txn)
-
- return self.runInteraction(func)
+ txn.execute(sql, keyvalues.values())
+ return self.cursor_to_dict(txn)
def _simple_update_one(self, table, keyvalues, updatevalues,
retcols=None):
@@ -315,7 +374,7 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched")
return ret
- return self.runInteraction(func)
+ return self.runInteraction("_simple_selectupdate_one", func)
def _simple_delete_one(self, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
@@ -327,7 +386,7 @@ class SQLBaseStore(object):
"""
sql = "DELETE FROM %s WHERE %s" % (
table,
- " AND ".join("%s = ?" % (k) for k in keyvalues)
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
def func(txn):
@@ -336,7 +395,25 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "more than one row matched")
- return self.runInteraction(func)
+ return self.runInteraction("_simple_delete_one", func)
+
+ def _simple_delete(self, table, keyvalues):
+ """Executes a DELETE query on the named table.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+
+ return self.runInteraction("_simple_delete", self._simple_delete_txn)
+
+ def _simple_delete_txn(self, txn, table, keyvalues):
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+ )
+
+ return txn.execute(sql, keyvalues.values())
def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the
@@ -354,7 +431,7 @@ class SQLBaseStore(object):
return 0
return max_id
- return self.runInteraction(func)
+ return self.runInteraction("_simple_max_id", func)
def _parse_event_from_row(self, row_dict):
d = copy.deepcopy({k: v for k, v in row_dict.items()})
@@ -363,6 +440,10 @@ class SQLBaseStore(object):
d.pop("topological_ordering", None)
d.pop("processed", None)
d["origin_server_ts"] = d.pop("ts", 0)
+ replaces_state = d.pop("prev_state", None)
+
+ if replaces_state:
+ d["replaces_state"] = replaces_state
d.update(json.loads(row_dict["unrecognized_keys"]))
d["content"] = json.loads(d["content"])
@@ -377,23 +458,68 @@ class SQLBaseStore(object):
**d
)
+ def _get_events_txn(self, txn, event_ids):
+ # FIXME (erikj): This should be batched?
+
+ sql = "SELECT * FROM events WHERE event_id = ?"
+
+ event_rows = []
+ for e_id in event_ids:
+ c = txn.execute(sql, (e_id,))
+ event_rows.extend(self.cursor_to_dict(c))
+
+ return self._parse_events_txn(txn, event_rows)
+
def _parse_events(self, rows):
- return self.runInteraction(self._parse_events_txn, rows)
+ return self.runInteraction(
+ "_parse_events", self._parse_events_txn, rows
+ )
def _parse_events_txn(self, txn, rows):
events = [self._parse_event_from_row(r) for r in rows]
- sql = "SELECT * FROM events WHERE event_id = ?"
+ select_event_sql = "SELECT * FROM events WHERE event_id = ?"
+
+ for i, ev in enumerate(events):
+ signatures = self._get_event_signatures_txn(
+ txn, ev.event_id,
+ )
- for ev in events:
- if hasattr(ev, "prev_state"):
- # Load previous state_content.
- # TODO: Should we be pulling this out above?
- cursor = txn.execute(sql, (ev.prev_state,))
- prevs = self.cursor_to_dict(cursor)
- if prevs:
- prev = self._parse_event_from_row(prevs[0])
- ev.prev_content = prev.content
+ ev.signatures = {
+ n: {
+ k: encode_base64(v) for k, v in s.items()
+ }
+ for n, s in signatures.items()
+ }
+
+ prevs = self._get_prev_events_and_state(txn, ev.event_id)
+
+ ev.prev_events = [
+ (e_id, h)
+ for e_id, h, is_state in prevs
+ if is_state == 0
+ ]
+
+ ev.auth_events = self._get_auth_events(txn, ev.event_id)
+
+ if hasattr(ev, "state_key"):
+ ev.prev_state = [
+ (e_id, h)
+ for e_id, h, is_state in prevs
+ if is_state == 1
+ ]
+
+ if hasattr(ev, "replaces_state"):
+ # Load previous state_content.
+ # FIXME (erikj): Handle multiple prev_states.
+ cursor = txn.execute(
+ select_event_sql,
+ (ev.replaces_state,)
+ )
+ prevs = self.cursor_to_dict(cursor)
+ if prevs:
+ prev = self._parse_event_from_row(prevs[0])
+ ev.prev_content = prev.content
if not hasattr(ev, "redacted"):
logger.debug("Doesn't have redacted key: %s", ev)
@@ -401,15 +527,16 @@ class SQLBaseStore(object):
if ev.redacted:
# Get the redaction event.
- sql = "SELECT * FROM events WHERE event_id = ?"
- txn.execute(sql, (ev.redacted,))
+ select_event_sql = "SELECT * FROM events WHERE event_id = ?"
+ txn.execute(select_event_sql, (ev.redacted,))
del_evs = self._parse_events_txn(
txn, self.cursor_to_dict(txn)
)
if del_evs:
- prune_event(ev)
+ ev = prune_event(ev)
+ events[i] = ev
ev.redacted_because = del_evs[0]
return events
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 52373a28a6..d6a7113b9c 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -95,6 +95,7 @@ class DirectoryStore(SQLBaseStore):
def delete_room_alias(self, room_alias):
return self.runInteraction(
+ "delete_room_alias",
self._delete_room_alias_txn,
room_alias,
)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
new file mode 100644
index 0000000000..6c559f8f63
--- /dev/null
+++ b/synapse/storage/event_federation.py
@@ -0,0 +1,386 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from syutil.base64util import encode_base64
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class EventFederationStore(SQLBaseStore):
+ """ Responsible for storing and serving up the various graphs associated
+ with an event. Including the main event graph and the auth chains for an
+ event.
+
+ Also has methods for getting the front (latest) and back (oldest) edges
+ of the event graphs. These are used to generate the parents for new events
+ and backfilling from another server respectively.
+ """
+
+ def get_auth_chain(self, event_id):
+ return self.runInteraction(
+ "get_auth_chain",
+ self._get_auth_chain_txn,
+ event_id
+ )
+
+ def _get_auth_chain_txn(self, txn, event_id):
+ results = self._get_auth_chain_ids_txn(txn, event_id)
+
+ sql = "SELECT * FROM events WHERE event_id = ?"
+ rows = []
+ for ev_id in results:
+ c = txn.execute(sql, (ev_id,))
+ rows.extend(self.cursor_to_dict(c))
+
+ return self._parse_events_txn(txn, rows)
+
+ def get_auth_chain_ids(self, event_id):
+ return self.runInteraction(
+ "get_auth_chain_ids",
+ self._get_auth_chain_ids_txn,
+ event_id
+ )
+
+ def _get_auth_chain_ids_txn(self, txn, event_id):
+ results = set()
+
+ base_sql = (
+ "SELECT auth_id FROM event_auth WHERE %s"
+ )
+
+ front = set([event_id])
+ while front:
+ sql = base_sql % (
+ " OR ".join(["event_id=?"] * len(front)),
+ )
+
+ txn.execute(sql, list(front))
+ front = [r[0] for r in txn.fetchall()]
+ results.update(front)
+
+ return list(results)
+
+ def get_oldest_events_in_room(self, room_id):
+ return self.runInteraction(
+ "get_oldest_events_in_room",
+ self._get_oldest_events_in_room_txn,
+ room_id,
+ )
+
+ def _get_oldest_events_in_room_txn(self, txn, room_id):
+ return self._simple_select_onecol_txn(
+ txn,
+ table="event_backward_extremities",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="event_id",
+ )
+
+ def get_latest_events_in_room(self, room_id):
+ return self.runInteraction(
+ "get_latest_events_in_room",
+ self._get_latest_events_in_room,
+ room_id,
+ )
+
+ def _get_latest_events_in_room(self, txn, room_id):
+ sql = (
+ "SELECT e.event_id, e.depth FROM events as e "
+ "INNER JOIN event_forward_extremities as f "
+ "ON e.event_id = f.event_id "
+ "WHERE f.room_id = ?"
+ )
+
+ txn.execute(sql, (room_id, ))
+
+ results = []
+ for event_id, depth in txn.fetchall():
+ hashes = self._get_event_reference_hashes_txn(txn, event_id)
+ prev_hashes = {
+ k: encode_base64(v) for k, v in hashes.items()
+ if k == "sha256"
+ }
+ results.append((event_id, prev_hashes, depth))
+
+ return results
+
+ def _get_latest_state_in_room(self, txn, room_id, type, state_key):
+ event_ids = self._simple_select_onecol_txn(
+ txn,
+ table="state_forward_extremities",
+ keyvalues={
+ "room_id": room_id,
+ "type": type,
+ "state_key": state_key,
+ },
+ retcol="event_id",
+ )
+
+ results = []
+ for event_id in event_ids:
+ hashes = self._get_event_reference_hashes_txn(txn, event_id)
+ prev_hashes = {
+ k: encode_base64(v) for k, v in hashes.items()
+ if k == "sha256"
+ }
+ results.append((event_id, prev_hashes))
+
+ return results
+
+ def _get_prev_events(self, txn, event_id):
+ results = self._get_prev_events_and_state(
+ txn,
+ event_id,
+ is_state=0,
+ )
+
+ return [(e_id, h, ) for e_id, h, _ in results]
+
+ def _get_prev_state(self, txn, event_id):
+ results = self._get_prev_events_and_state(
+ txn,
+ event_id,
+ is_state=1,
+ )
+
+ return [(e_id, h, ) for e_id, h, _ in results]
+
+ def _get_prev_events_and_state(self, txn, event_id, is_state=None):
+ keyvalues = {
+ "event_id": event_id,
+ }
+
+ if is_state is not None:
+ keyvalues["is_state"] = is_state
+
+ res = self._simple_select_list_txn(
+ txn,
+ table="event_edges",
+ keyvalues=keyvalues,
+ retcols=["prev_event_id", "is_state"],
+ )
+
+ results = []
+ for d in res:
+ hashes = self._get_event_reference_hashes_txn(
+ txn,
+ d["prev_event_id"]
+ )
+ prev_hashes = {
+ k: encode_base64(v) for k, v in hashes.items()
+ if k == "sha256"
+ }
+ results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
+
+ return results
+
+ def _get_auth_events(self, txn, event_id):
+ auth_ids = self._simple_select_onecol_txn(
+ txn,
+ table="event_auth",
+ keyvalues={
+ "event_id": event_id,
+ },
+ retcol="auth_id",
+ )
+
+ results = []
+ for auth_id in auth_ids:
+ hashes = self._get_event_reference_hashes_txn(txn, auth_id)
+ prev_hashes = {
+ k: encode_base64(v) for k, v in hashes.items()
+ if k == "sha256"
+ }
+ results.append((auth_id, prev_hashes))
+
+ return results
+
+ def get_min_depth(self, room_id):
+ """ For hte given room, get the minimum depth we have seen for it.
+ """
+ return self.runInteraction(
+ "get_min_depth",
+ self._get_min_depth_interaction,
+ room_id,
+ )
+
+ def _get_min_depth_interaction(self, txn, room_id):
+ min_depth = self._simple_select_one_onecol_txn(
+ txn,
+ table="room_depth",
+ keyvalues={"room_id": room_id},
+ retcol="min_depth",
+ allow_none=True,
+ )
+
+ return int(min_depth) if min_depth is not None else None
+
+ def _update_min_depth_for_room_txn(self, txn, room_id, depth):
+ min_depth = self._get_min_depth_interaction(txn, room_id)
+
+ do_insert = depth < min_depth if min_depth else True
+
+ if do_insert:
+ self._simple_insert_txn(
+ txn,
+ table="room_depth",
+ values={
+ "room_id": room_id,
+ "min_depth": depth,
+ },
+ or_replace=True,
+ )
+
+ def _handle_prev_events(self, txn, outlier, event_id, prev_events,
+ room_id):
+ """
+ For the given event, update the event edges table and forward and
+ backward extremities tables.
+ """
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_insert_txn(
+ txn,
+ table="event_edges",
+ values={
+ "event_id": event_id,
+ "prev_event_id": e_id,
+ "room_id": room_id,
+ "is_state": 0,
+ },
+ or_ignore=True,
+ )
+
+ # Update the extremities table if this is not an outlier.
+ if not outlier:
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_delete_txn(
+ txn,
+ table="event_forward_extremities",
+ keyvalues={
+ "event_id": e_id,
+ "room_id": room_id,
+ }
+ )
+
+ # We only insert as a forward extremity the new event if there are
+ # no other events that reference it as a prev event
+ query = (
+ "INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
+ "SELECT ?, ? WHERE NOT EXISTS ("
+ "SELECT 1 FROM %(event_edges)s WHERE "
+ "prev_event_id = ? "
+ ")"
+ ) % {
+ "table": "event_forward_extremities",
+ "event_edges": "event_edges",
+ }
+
+ logger.debug("query: %s", query)
+
+ txn.execute(query, (event_id, room_id, event_id))
+
+ # Insert all the prev_events as a backwards thing, they'll get
+ # deleted in a second if they're incorrect anyway.
+ for e_id, _ in prev_events:
+ # TODO (erikj): This could be done as a bulk insert
+ self._simple_insert_txn(
+ txn,
+ table="event_backward_extremities",
+ values={
+ "event_id": e_id,
+ "room_id": room_id,
+ },
+ or_ignore=True,
+ )
+
+ # Also delete from the backwards extremities table all ones that
+ # reference events that we have already seen
+ query = (
+ "DELETE FROM event_backward_extremities WHERE EXISTS ("
+ "SELECT 1 FROM events "
+ "WHERE "
+ "event_backward_extremities.event_id = events.event_id "
+ "AND not events.outlier "
+ ")"
+ )
+ txn.execute(query)
+
+ def get_backfill_events(self, room_id, event_list, limit):
+ """Get a list of Events for a given topic that occurred before (and
+ including) the events in event_list. Return a list of max size `limit`
+
+ Args:
+ txn
+ room_id (str)
+ event_list (list)
+ limit (int)
+ """
+ return self.runInteraction(
+ "get_backfill_events",
+ self._get_backfill_events, room_id, event_list, limit
+ )
+
+ def _get_backfill_events(self, txn, room_id, event_list, limit):
+ logger.debug(
+ "_get_backfill_events: %s, %s, %s",
+ room_id, repr(event_list), limit
+ )
+
+ event_results = event_list
+
+ front = event_list
+
+ query = (
+ "SELECT prev_event_id FROM event_edges "
+ "WHERE room_id = ? AND event_id = ? "
+ "LIMIT ?"
+ )
+
+ # We iterate through all event_ids in `front` to select their previous
+ # events. These are dumped in `new_front`.
+ # We continue until we reach the limit *or* new_front is empty (i.e.,
+ # we've run out of things to select
+ while front and len(event_results) < limit:
+
+ new_front = []
+ for event_id in front:
+ logger.debug(
+ "_backfill_interaction: id=%s",
+ event_id
+ )
+
+ txn.execute(
+ query,
+ (room_id, event_id, limit - len(event_results))
+ )
+
+ for row in txn.fetchall():
+ logger.debug(
+ "_backfill_interaction: got id=%s",
+ *row
+ )
+ new_front.append(row[0])
+
+ front = new_front
+ event_results += new_front
+
+ return self._get_events_txn(txn, event_results)
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
deleted file mode 100644
index d70467dcd6..0000000000
--- a/synapse/storage/pdu.py
+++ /dev/null
@@ -1,915 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from twisted.internet import defer
-
-from ._base import SQLBaseStore, Table, JoinHelper
-
-from synapse.federation.units import Pdu
-from synapse.util.logutils import log_function
-
-from collections import namedtuple
-
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-class PduStore(SQLBaseStore):
- """A collection of queries for handling PDUs.
- """
-
- def get_pdu(self, pdu_id, origin):
- """Given a pdu_id and origin, get a PDU.
-
- Args:
- txn
- pdu_id (str)
- origin (str)
-
- Returns:
- PduTuple: If the pdu does not exist in the database, returns None
- """
-
- return self.runInteraction(
- self._get_pdu_tuple, pdu_id, origin
- )
-
- def _get_pdu_tuple(self, txn, pdu_id, origin):
- res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
- return res[0] if res else None
-
- def _get_pdu_tuples(self, txn, pdu_id_tuples):
- results = []
- for pdu_id, origin in pdu_id_tuples:
- txn.execute(
- PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
- (pdu_id, origin)
- )
-
- edges = [
- (r.prev_pdu_id, r.prev_origin)
- for r in PduEdgesTable.decode_results(txn.fetchall())
- ]
-
- query = (
- "SELECT %(fields)s FROM %(pdus)s as p "
- "LEFT JOIN %(state)s as s "
- "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
- "WHERE p.pdu_id = ? AND p.origin = ? "
- ) % {
- "fields": _pdu_state_joiner.get_fields(
- PdusTable="p", StatePdusTable="s"),
- "pdus": PdusTable.table_name,
- "state": StatePdusTable.table_name,
- }
-
- txn.execute(query, (pdu_id, origin))
-
- row = txn.fetchone()
- if row:
- results.append(PduTuple(PduEntry(*row), edges))
-
- return results
-
- def get_current_state_for_context(self, context):
- """Get a list of PDUs that represent the current state for a given
- context
-
- Args:
- context (str)
-
- Returns:
- list: A list of PduTuples
- """
-
- return self.runInteraction(
- self._get_current_state_for_context,
- context
- )
-
- def _get_current_state_for_context(self, txn, context):
- query = (
- "SELECT pdu_id, origin FROM %s WHERE context = ?"
- % CurrentStateTable.table_name
- )
-
- logger.debug("get_current_state %s, Args=%s", query, context)
- txn.execute(query, (context,))
-
- res = txn.fetchall()
-
- logger.debug("get_current_state %d results", len(res))
-
- return self._get_pdu_tuples(txn, res)
-
- def _persist_pdu_txn(self, txn, prev_pdus, cols):
- """Inserts a (non-state) PDU into the database.
-
- Args:
- txn,
- prev_pdus (list)
- **cols: The columns to insert into the PdusTable.
- """
- entry = PdusTable.EntryType(
- **{k: cols.get(k, None) for k in PdusTable.fields}
- )
-
- txn.execute(PdusTable.insert_statement(), entry)
-
- self._handle_prev_pdus(
- txn, entry.outlier, entry.pdu_id, entry.origin,
- prev_pdus, entry.context
- )
-
- def mark_pdu_as_processed(self, pdu_id, pdu_origin):
- """Mark a received PDU as processed.
-
- Args:
- txn
- pdu_id (str)
- pdu_origin (str)
- """
-
- return self.runInteraction(
- self._mark_as_processed, pdu_id, pdu_origin
- )
-
- def _mark_as_processed(self, txn, pdu_id, pdu_origin):
- txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)
-
- def get_all_pdus_from_context(self, context):
- """Get a list of all PDUs for a given context."""
- return self.runInteraction(
- self._get_all_pdus_from_context, context,
- )
-
- def _get_all_pdus_from_context(self, txn, context):
- query = (
- "SELECT pdu_id, origin FROM %s "
- "WHERE context = ?"
- ) % PdusTable.table_name
-
- txn.execute(query, (context,))
-
- return self._get_pdu_tuples(txn, txn.fetchall())
-
- def get_backfill(self, context, pdu_list, limit):
- """Get a list of Pdus for a given topic that occured before (and
- including) the pdus in pdu_list. Return a list of max size `limit`.
-
- Args:
- txn
- context (str)
- pdu_list (list)
- limit (int)
-
- Return:
- list: A list of PduTuples
- """
- return self.runInteraction(
- self._get_backfill, context, pdu_list, limit
- )
-
- def _get_backfill(self, txn, context, pdu_list, limit):
- logger.debug(
- "backfill: %s, %s, %s",
- context, repr(pdu_list), limit
- )
-
- # We seed the pdu_results with the things from the pdu_list.
- pdu_results = pdu_list
-
- front = pdu_list
-
- query = (
- "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
- "WHERE context = ? AND pdu_id = ? AND origin = ? "
- "LIMIT ?"
- ) % {
- "edges_table": PduEdgesTable.table_name,
- }
-
- # We iterate through all pdu_ids in `front` to select their previous
- # pdus. These are dumped in `new_front`. We continue until we reach the
- # limit *or* new_front is empty (i.e., we've run out of things to
- # select
- while front and len(pdu_results) < limit:
-
- new_front = []
- for pdu_id, origin in front:
- logger.debug(
- "_backfill_interaction: i=%s, o=%s",
- pdu_id, origin
- )
-
- txn.execute(
- query,
- (context, pdu_id, origin, limit - len(pdu_results))
- )
-
- for row in txn.fetchall():
- logger.debug(
- "_backfill_interaction: got i=%s, o=%s",
- *row
- )
- new_front.append(row)
-
- front = new_front
- pdu_results += new_front
-
- # We also want to update the `prev_pdus` attributes before returning.
- return self._get_pdu_tuples(txn, pdu_results)
-
- def get_min_depth_for_context(self, context):
- """Get the current minimum depth for a context
-
- Args:
- txn
- context (str)
- """
- return self.runInteraction(
- self._get_min_depth_for_context, context
- )
-
- def _get_min_depth_for_context(self, txn, context):
- return self._get_min_depth_interaction(txn, context)
-
- def _get_min_depth_interaction(self, txn, context):
- txn.execute(
- "SELECT min_depth FROM %s WHERE context = ?"
- % ContextDepthTable.table_name,
- (context,)
- )
-
- row = txn.fetchone()
-
- return row[0] if row else None
-
- def _update_min_depth_for_context_txn(self, txn, context, depth):
- """Update the minimum `depth` of the given context, which is the line
- on which we stop backfilling backwards.
-
- Args:
- context (str)
- depth (int)
- """
- min_depth = self._get_min_depth_interaction(txn, context)
-
- do_insert = depth < min_depth if min_depth else True
-
- if do_insert:
- txn.execute(
- "INSERT OR REPLACE INTO %s (context, min_depth) "
- "VALUES (?,?)" % ContextDepthTable.table_name,
- (context, depth)
- )
-
- def _get_latest_pdus_in_context(self, txn, context):
- """Get's a list of the most current pdus for a given context. This is
- used when we are sending a Pdu and need to fill out the `prev_pdus`
- key
-
- Args:
- txn
- context
- """
- query = (
- "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
- "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
- "AND f.origin = p.origin "
- "WHERE f.context = ?"
- ) % {
- "pdus": PdusTable.table_name,
- "forward": PduForwardExtremitiesTable.table_name,
- }
-
- logger.debug("get_prev query: %s", query)
-
- txn.execute(
- query,
- (context, )
- )
-
- results = txn.fetchall()
-
- return [(row[0], row[1], row[2]) for row in results]
-
- @defer.inlineCallbacks
- def get_oldest_pdus_in_context(self, context):
- """Get a list of Pdus that we haven't backfilled beyond yet (and havent
- seen). This list is used when we want to backfill backwards and is the
- list we send to the remote server.
-
- Args:
- txn
- context (str)
-
- Returns:
- list: A list of PduIdTuple.
- """
- results = yield self._execute(
- None,
- "SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
- % {"back": PduBackwardExtremitiesTable.table_name, },
- context
- )
-
- defer.returnValue([PduIdTuple(i, o) for i, o in results])
-
- def is_pdu_new(self, pdu_id, origin, context, depth):
- """For a given Pdu, try and figure out if it's 'new', i.e., if it's
- not something we got randomly from the past, for example when we
- request the current state of the room that will probably return a bunch
- of pdus from before we joined.
-
- Args:
- txn
- pdu_id (str)
- origin (str)
- context (str)
- depth (int)
-
- Returns:
- bool
- """
-
- return self.runInteraction(
- self._is_pdu_new,
- pdu_id=pdu_id,
- origin=origin,
- context=context,
- depth=depth
- )
-
- def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
- # If depth > min depth in back table, then we classify it as new.
- # OR if there is nothing in the back table, then it kinda needs to
- # be a new thing.
- query = (
- "SELECT min(p.depth) FROM %(edges)s as e "
- "INNER JOIN %(back)s as b "
- "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
- "INNER JOIN %(pdus)s as p "
- "ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
- "WHERE p.context = ?"
- ) % {
- "pdus": PdusTable.table_name,
- "edges": PduEdgesTable.table_name,
- "back": PduBackwardExtremitiesTable.table_name,
- }
-
- txn.execute(query, (context,))
-
- min_depth, = txn.fetchone()
-
- if not min_depth or depth > int(min_depth):
- logger.debug(
- "is_new true: id=%s, o=%s, d=%s min_depth=%s",
- pdu_id, origin, depth, min_depth
- )
- return True
-
- # If this pdu is in the forwards table, then it also is a new one
- query = (
- "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
- ) % {
- "forward": PduForwardExtremitiesTable.table_name,
- }
-
- txn.execute(query, (pdu_id, origin))
-
- # Did we get anything?
- if txn.fetchall():
- logger.debug(
- "is_new true: id=%s, o=%s, d=%s was forward",
- pdu_id, origin, depth
- )
- return True
-
- logger.debug(
- "is_new false: id=%s, o=%s, d=%s",
- pdu_id, origin, depth
- )
-
- # FINE THEN. It's probably old.
- return False
-
- @staticmethod
- @log_function
- def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
- context):
- txn.executemany(
- PduEdgesTable.insert_statement(),
- [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
- )
-
- # Update the extremities table if this is not an outlier.
- if not outlier:
-
- # First, we delete the new one from the forwards extremities table.
- query = (
- "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
- % PduForwardExtremitiesTable.table_name
- )
- txn.executemany(query, prev_pdus)
-
- # We only insert as a forward extremety the new pdu if there are no
- # other pdus that reference it as a prev pdu
- query = (
- "INSERT INTO %(table)s (pdu_id, origin, context) "
- "SELECT ?, ?, ? WHERE NOT EXISTS ("
- "SELECT 1 FROM %(pdu_edges)s WHERE "
- "prev_pdu_id = ? AND prev_origin = ?"
- ")"
- ) % {
- "table": PduForwardExtremitiesTable.table_name,
- "pdu_edges": PduEdgesTable.table_name
- }
-
- logger.debug("query: %s", query)
-
- txn.execute(query, (pdu_id, origin, context, pdu_id, origin))
-
- # Insert all the prev_pdus as a backwards thing, they'll get
- # deleted in a second if they're incorrect anyway.
- txn.executemany(
- PduBackwardExtremitiesTable.insert_statement(),
- [(i, o, context) for i, o in prev_pdus]
- )
-
- # Also delete from the backwards extremities table all ones that
- # reference pdus that we have already seen
- query = (
- "DELETE FROM %(pdu_back)s WHERE EXISTS ("
- "SELECT 1 FROM %(pdus)s AS pdus "
- "WHERE "
- "%(pdu_back)s.pdu_id = pdus.pdu_id "
- "AND %(pdu_back)s.origin = pdus.origin "
- "AND not pdus.outlier "
- ")"
- ) % {
- "pdu_back": PduBackwardExtremitiesTable.table_name,
- "pdus": PdusTable.table_name,
- }
- txn.execute(query)
-
-
-class StatePduStore(SQLBaseStore):
- """A collection of queries for handling state PDUs.
- """
-
- def _persist_state_txn(self, txn, prev_pdus, cols):
- """Inserts a state PDU into the database
-
- Args:
- txn,
- prev_pdus (list)
- **cols: The columns to insert into the PdusTable and StatePdusTable
- """
- pdu_entry = PdusTable.EntryType(
- **{k: cols.get(k, None) for k in PdusTable.fields}
- )
- state_entry = StatePdusTable.EntryType(
- **{k: cols.get(k, None) for k in StatePdusTable.fields}
- )
-
- logger.debug("Inserting pdu: %s", repr(pdu_entry))
- logger.debug("Inserting state: %s", repr(state_entry))
-
- txn.execute(PdusTable.insert_statement(), pdu_entry)
- txn.execute(StatePdusTable.insert_statement(), state_entry)
-
- self._handle_prev_pdus(
- txn,
- pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
- pdu_entry.context
- )
-
- def get_unresolved_state_tree(self, new_state_pdu):
- return self.runInteraction(
- self._get_unresolved_state_tree, new_state_pdu
- )
-
- @log_function
- def _get_unresolved_state_tree(self, txn, new_pdu):
- current = self._get_current_interaction(
- txn,
- new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
- )
-
- ReturnType = namedtuple(
- "StateReturnType", ["new_branch", "current_branch"]
- )
- return_value = ReturnType([new_pdu], [])
-
- if not current:
- logger.debug("get_unresolved_state_tree No current state.")
- return (return_value, None)
-
- return_value.current_branch.append(current)
-
- enum_branches = self._enumerate_state_branches(
- txn, new_pdu, current
- )
-
- missing_branch = None
- for branch, prev_state, state in enum_branches:
- if state:
- return_value[branch].append(state)
- else:
- # We don't have prev_state :(
- missing_branch = branch
- break
-
- return (return_value, missing_branch)
-
- def update_current_state(self, pdu_id, origin, context, pdu_type,
- state_key):
- return self.runInteraction(
- self._update_current_state,
- pdu_id, origin, context, pdu_type, state_key
- )
-
- def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
- state_key):
- query = (
- "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
- ) % {
- "curr": CurrentStateTable.table_name,
- "fields": CurrentStateTable.get_fields_string(),
- "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
- }
-
- query_args = CurrentStateTable.EntryType(
- pdu_id=pdu_id,
- origin=origin,
- context=context,
- pdu_type=pdu_type,
- state_key=state_key
- )
-
- txn.execute(query, query_args)
-
- def get_current_state_pdu(self, context, pdu_type, state_key):
- """For a given context, pdu_type, state_key 3-tuple, return what is
- currently considered the current state.
-
- Args:
- txn
- context (str)
- pdu_type (str)
- state_key (str)
-
- Returns:
- PduEntry
- """
-
- return self.runInteraction(
- self._get_current_state_pdu, context, pdu_type, state_key
- )
-
- def _get_current_state_pdu(self, txn, context, pdu_type, state_key):
- return self._get_current_interaction(txn, context, pdu_type, state_key)
-
- def _get_current_interaction(self, txn, context, pdu_type, state_key):
- logger.debug(
- "_get_current_interaction %s %s %s",
- context, pdu_type, state_key
- )
-
- fields = _pdu_state_joiner.get_fields(
- PdusTable="p", StatePdusTable="s")
-
- current_query = (
- "SELECT %(fields)s FROM %(state)s as s "
- "INNER JOIN %(pdus)s as p "
- "ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
- "INNER JOIN %(curr)s as c "
- "ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
- "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
- ) % {
- "fields": fields,
- "curr": CurrentStateTable.table_name,
- "state": StatePdusTable.table_name,
- "pdus": PdusTable.table_name,
- }
-
- txn.execute(
- current_query,
- (context, pdu_type, state_key)
- )
-
- row = txn.fetchone()
-
- result = PduEntry(*row) if row else None
-
- if not result:
- logger.debug("_get_current_interaction not found")
- else:
- logger.debug(
- "_get_current_interaction found %s %s",
- result.pdu_id, result.origin
- )
-
- return result
-
- def handle_new_state(self, new_pdu):
- """Actually perform conflict resolution on the new_pdu on the
- assumption we have all the pdus required to perform it.
-
- Args:
- new_pdu
-
- Returns:
- bool: True if the new_pdu clobbered the current state, False if not
- """
- return self.runInteraction(
- self._handle_new_state, new_pdu
- )
-
- def _handle_new_state(self, txn, new_pdu):
- logger.debug(
- "handle_new_state %s %s",
- new_pdu.pdu_id, new_pdu.origin
- )
-
- current = self._get_current_interaction(
- txn,
- new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
- )
-
- is_current = False
-
- if (not current or not current.prev_state_id
- or not current.prev_state_origin):
- # Oh, we don't have any state for this yet.
- is_current = True
- elif (current.pdu_id == new_pdu.prev_state_id
- and current.origin == new_pdu.prev_state_origin):
- # Oh! A direct clobber. Just do it.
- is_current = True
- else:
- ##
- # Ok, now loop through until we get to a common ancestor.
- max_new = int(new_pdu.power_level)
- max_current = int(current.power_level)
-
- enum_branches = self._enumerate_state_branches(
- txn, new_pdu, current
- )
- for branch, prev_state, state in enum_branches:
- if not state:
- raise RuntimeError(
- "Could not find state_pdu %s %s" %
- (
- prev_state.prev_state_id,
- prev_state.prev_state_origin
- )
- )
-
- if branch == 0:
- max_new = max(int(state.depth), max_new)
- else:
- max_current = max(int(state.depth), max_current)
-
- is_current = max_new > max_current
-
- if is_current:
- logger.debug("handle_new_state make current")
-
- # Right, this is a new thing, so woo, just insert it.
- txn.execute(
- "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
- % {
- "curr": CurrentStateTable.table_name,
- "fields": CurrentStateTable.get_fields_string(),
- "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
- },
- CurrentStateTable.EntryType(
- *(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
- )
- )
- else:
- logger.debug("handle_new_state not current")
-
- logger.debug("handle_new_state done")
-
- return is_current
-
- @log_function
- def _enumerate_state_branches(self, txn, pdu_a, pdu_b):
- branch_a = pdu_a
- branch_b = pdu_b
-
- while True:
- if (branch_a.pdu_id == branch_b.pdu_id
- and branch_a.origin == branch_b.origin):
- # Woo! We found a common ancestor
- logger.debug("_enumerate_state_branches Found common ancestor")
- break
-
- do_branch_a = (
- hasattr(branch_a, "prev_state_id") and
- branch_a.prev_state_id
- )
-
- do_branch_b = (
- hasattr(branch_b, "prev_state_id") and
- branch_b.prev_state_id
- )
-
- logger.debug(
- "do_branch_a=%s, do_branch_b=%s",
- do_branch_a, do_branch_b
- )
-
- if do_branch_a and do_branch_b:
- do_branch_a = int(branch_a.depth) > int(branch_b.depth)
-
- if do_branch_a:
- pdu_tuple = PduIdTuple(
- branch_a.prev_state_id,
- branch_a.prev_state_origin
- )
-
- prev_branch = branch_a
-
- logger.debug("getting branch_a prev %s", pdu_tuple)
- branch_a = self._get_pdu_tuple(txn, *pdu_tuple)
- if branch_a:
- branch_a = Pdu.from_pdu_tuple(branch_a)
-
- logger.debug("branch_a=%s", branch_a)
-
- yield (0, prev_branch, branch_a)
-
- if not branch_a:
- break
- elif do_branch_b:
- pdu_tuple = PduIdTuple(
- branch_b.prev_state_id,
- branch_b.prev_state_origin
- )
-
- prev_branch = branch_b
-
- logger.debug("getting branch_b prev %s", pdu_tuple)
- branch_b = self._get_pdu_tuple(txn, *pdu_tuple)
- if branch_b:
- branch_b = Pdu.from_pdu_tuple(branch_b)
-
- logger.debug("branch_b=%s", branch_b)
-
- yield (1, prev_branch, branch_b)
-
- if not branch_b:
- break
- else:
- break
-
-
-class PdusTable(Table):
- table_name = "pdus"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- "pdu_type",
- "ts",
- "depth",
- "is_state",
- "content_json",
- "unrecognized_keys",
- "outlier",
- "have_processed",
- ]
-
- EntryType = namedtuple("PdusEntry", fields)
-
-
-class PduDestinationsTable(Table):
- table_name = "pdu_destinations"
-
- fields = [
- "pdu_id",
- "origin",
- "destination",
- "delivered_ts",
- ]
-
- EntryType = namedtuple("PduDestinationsEntry", fields)
-
-
-class PduEdgesTable(Table):
- table_name = "pdu_edges"
-
- fields = [
- "pdu_id",
- "origin",
- "prev_pdu_id",
- "prev_origin",
- "context"
- ]
-
- EntryType = namedtuple("PduEdgesEntry", fields)
-
-
-class PduForwardExtremitiesTable(Table):
- table_name = "pdu_forward_extremities"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- ]
-
- EntryType = namedtuple("PduForwardExtremitiesEntry", fields)
-
-
-class PduBackwardExtremitiesTable(Table):
- table_name = "pdu_backward_extremities"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- ]
-
- EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)
-
-
-class ContextDepthTable(Table):
- table_name = "context_depth"
-
- fields = [
- "context",
- "min_depth",
- ]
-
- EntryType = namedtuple("ContextDepthEntry", fields)
-
-
-class StatePdusTable(Table):
- table_name = "state_pdus"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- "pdu_type",
- "state_key",
- "power_level",
- "prev_state_id",
- "prev_state_origin",
- ]
-
- EntryType = namedtuple("StatePdusEntry", fields)
-
-
-class CurrentStateTable(Table):
- table_name = "current_state"
-
- fields = [
- "pdu_id",
- "origin",
- "context",
- "pdu_type",
- "state_key",
- ]
-
- EntryType = namedtuple("CurrentStateEntry", fields)
-
-_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)
-
-
-# TODO: These should probably be put somewhere more sensible
-PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))
-
-PduEntry = _pdu_state_joiner.EntryType
-""" We are always interested in the join of the PdusTable and StatePdusTable,
-rather than just the PdusTable.
-
-This does not include a prev_pdus key.
-"""
-
-PduTuple = namedtuple(
- "PduTuple",
- ("pdu_entry", "prev_pdu_list")
-)
-""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
-the `prev_pdus` key of a PDU.
-"""
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 719806f82b..1f89d77344 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -62,8 +62,10 @@ class RegistrationStore(SQLBaseStore):
Raises:
StoreError if the user_id could not be registered.
"""
- yield self.runInteraction(self._register, user_id, token,
- password_hash)
+ yield self.runInteraction(
+ "register",
+ self._register, user_id, token, password_hash
+ )
def _register(self, txn, user_id, token, password_hash):
now = int(self.clock.time())
@@ -100,17 +102,22 @@ class RegistrationStore(SQLBaseStore):
StoreError if no user was found.
"""
return self.runInteraction(
+ "get_user_by_token",
self._query_for_auth,
token
)
+ @defer.inlineCallbacks
def is_server_admin(self, user):
- return self._simple_select_one_onecol(
+ res = yield self._simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
+ allow_none=True,
)
+ defer.returnValue(res if res else False)
+
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.admin, access_tokens.device_id "
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 8cd46334cf..cc0513b8d2 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -132,209 +132,29 @@ class RoomStore(SQLBaseStore):
defer.returnValue(ret)
- @defer.inlineCallbacks
- def get_room_join_rule(self, room_id):
- sql = (
- "SELECT join_rule FROM room_join_rules as r "
- "INNER JOIN current_state_events as c "
- "ON r.event_id = c.event_id "
- "WHERE c.room_id = ? "
- )
-
- rows = yield self._execute(None, sql, room_id)
-
- if len(rows) == 1:
- defer.returnValue(rows[0][0])
- else:
- defer.returnValue(None)
-
- def get_power_level(self, room_id, user_id):
- return self.runInteraction(
- self._get_power_level,
- room_id, user_id,
- )
-
- def _get_power_level(self, txn, room_id, user_id):
- sql = (
- "SELECT level FROM room_power_levels as r "
- "INNER JOIN current_state_events as c "
- "ON r.event_id = c.event_id "
- "WHERE c.room_id = ? AND r.user_id = ? "
- )
-
- rows = txn.execute(sql, (room_id, user_id,)).fetchall()
-
- if len(rows) == 1:
- return rows[0][0]
-
- sql = (
- "SELECT level FROM room_default_levels as r "
- "INNER JOIN current_state_events as c "
- "ON r.event_id = c.event_id "
- "WHERE c.room_id = ? "
- )
-
- rows = txn.execute(sql, (room_id,)).fetchall()
-
- if len(rows) == 1:
- return rows[0][0]
- else:
- return None
-
- def get_ops_levels(self, room_id):
- return self.runInteraction(
- self._get_ops_levels,
- room_id,
- )
-
- def _get_ops_levels(self, txn, room_id):
- sql = (
- "SELECT ban_level, kick_level, redact_level "
- "FROM room_ops_levels as r "
- "INNER JOIN current_state_events as c "
- "ON r.event_id = c.event_id "
- "WHERE c.room_id = ? "
- )
-
- rows = txn.execute(sql, (room_id,)).fetchall()
-
- if len(rows) == 1:
- return OpsLevel(rows[0][0], rows[0][1], rows[0][2])
- else:
- return OpsLevel(None, None)
-
- def get_add_state_level(self, room_id):
- return self._get_level_from_table("room_add_state_levels", room_id)
-
- def get_send_event_level(self, room_id):
- return self._get_level_from_table("room_send_event_levels", room_id)
-
- @defer.inlineCallbacks
- def _get_level_from_table(self, table, room_id):
- sql = (
- "SELECT level FROM %(table)s as r "
- "INNER JOIN current_state_events as c "
- "ON r.event_id = c.event_id "
- "WHERE c.room_id = ? "
- ) % {"table": table}
-
- rows = yield self._execute(None, sql, room_id)
-
- if len(rows) == 1:
- defer.returnValue(rows[0][0])
- else:
- defer.returnValue(None)
-
def _store_room_topic_txn(self, txn, event):
- self._simple_insert_txn(
- txn,
- "topics",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "topic": event.topic,
- }
- )
+ if hasattr(event, "topic"):
+ self._simple_insert_txn(
+ txn,
+ "topics",
+ {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "topic": event.topic,
+ }
+ )
def _store_room_name_txn(self, txn, event):
- self._simple_insert_txn(
- txn,
- "room_names",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "name": event.name,
- }
- )
-
- def _store_join_rule(self, txn, event):
- self._simple_insert_txn(
- txn,
- "room_join_rules",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "join_rule": event.content["join_rule"],
- },
- )
-
- def _store_power_levels(self, txn, event):
- for user_id, level in event.content.items():
- if user_id == "default":
- self._simple_insert_txn(
- txn,
- "room_default_levels",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "level": level,
- },
- )
- else:
- self._simple_insert_txn(
- txn,
- "room_power_levels",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "user_id": user_id,
- "level": level
- },
- )
-
- def _store_default_level(self, txn, event):
- self._simple_insert_txn(
- txn,
- "room_default_levels",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "level": event.content["default_level"],
- },
- )
-
- def _store_add_state_level(self, txn, event):
- self._simple_insert_txn(
- txn,
- "room_add_state_levels",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "level": event.content["level"],
- },
- )
-
- def _store_send_event_level(self, txn, event):
- self._simple_insert_txn(
- txn,
- "room_send_event_levels",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "level": event.content["level"],
- },
- )
-
- def _store_ops_level(self, txn, event):
- content = {
- "event_id": event.event_id,
- "room_id": event.room_id,
- }
-
- if "kick_level" in event.content:
- content["kick_level"] = event.content["kick_level"]
-
- if "ban_level" in event.content:
- content["ban_level"] = event.content["ban_level"]
-
- if "redact_level" in event.content:
- content["redact_level"] = event.content["redact_level"]
-
- self._simple_insert_txn(
- txn,
- "room_ops_levels",
- content,
- )
+ if hasattr(event, "name"):
+ self._simple_insert_txn(
+ txn,
+ "room_names",
+ {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "name": event.name,
+ }
+ )
class RoomsTable(Table):
diff --git a/synapse/storage/schema/edge_pdus.sql b/synapse/storage/schema/edge_pdus.sql
deleted file mode 100644
index 8a00868065..0000000000
--- a/synapse/storage/schema/edge_pdus.sql
+++ /dev/null
@@ -1,31 +0,0 @@
-/* Copyright 2014 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-CREATE TABLE IF NOT EXISTS context_edge_pdus(
- id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
-CREATE TABLE IF NOT EXISTS origin_edge_pdus(
- id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
- pdu_id TEXT,
- origin TEXT,
- CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
-CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin);
-CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin);
diff --git a/synapse/storage/schema/event_edges.sql b/synapse/storage/schema/event_edges.sql
new file mode 100644
index 0000000000..be1c72a775
--- /dev/null
+++ b/synapse/storage/schema/event_edges.sql
@@ -0,0 +1,75 @@
+
+CREATE TABLE IF NOT EXISTS event_forward_extremities(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS ev_extrem_room ON event_forward_extremities(room_id);
+CREATE INDEX IF NOT EXISTS ev_extrem_id ON event_forward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_backward_extremities(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS ev_b_extrem_room ON event_backward_extremities(room_id);
+CREATE INDEX IF NOT EXISTS ev_b_extrem_id ON event_backward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_edges(
+ event_id TEXT NOT NULL,
+ prev_event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ is_state INTEGER NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (event_id, prev_event_id, room_id, is_state)
+);
+
+CREATE INDEX IF NOT EXISTS ev_edges_id ON event_edges(event_id);
+CREATE INDEX IF NOT EXISTS ev_edges_prev_id ON event_edges(prev_event_id);
+
+
+CREATE TABLE IF NOT EXISTS room_depth(
+ room_id TEXT NOT NULL,
+ min_depth INTEGER NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (room_id)
+);
+
+CREATE INDEX IF NOT EXISTS room_depth_room ON room_depth(room_id);
+
+
+create TABLE IF NOT EXISTS event_destinations(
+ event_id TEXT NOT NULL,
+ destination TEXT NOT NULL,
+ delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
+ CONSTRAINT uniqueness UNIQUE (event_id, destination) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS event_destinations_id ON event_destinations(event_id);
+
+
+CREATE TABLE IF NOT EXISTS state_forward_extremities(
+ event_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (event_id, room_id) ON CONFLICT REPLACE
+);
+
+CREATE INDEX IF NOT EXISTS st_extrem_keys ON state_forward_extremities(
+ room_id, type, state_key
+);
+CREATE INDEX IF NOT EXISTS st_extrem_id ON state_forward_extremities(event_id);
+
+
+CREATE TABLE IF NOT EXISTS event_auth(
+ event_id TEXT NOT NULL,
+ auth_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ CONSTRAINT uniqueness UNIQUE (event_id, auth_id, room_id)
+);
+
+CREATE INDEX IF NOT EXISTS evauth_edges_id ON event_auth(event_id);
+CREATE INDEX IF NOT EXISTS evauth_edges_auth_id ON event_auth(auth_id);
\ No newline at end of file
diff --git a/synapse/storage/schema/event_signatures.sql b/synapse/storage/schema/event_signatures.sql
new file mode 100644
index 0000000000..4efa8a3e63
--- /dev/null
+++ b/synapse/storage/schema/event_signatures.sql
@@ -0,0 +1,65 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_content_hashes (
+ event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS event_content_hashes_id ON event_content_hashes(
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_reference_hashes (
+ event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, algorithm)
+);
+
+CREATE INDEX IF NOT EXISTS event_reference_hashes_id ON event_reference_hashes (
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_signatures (
+ event_id TEXT,
+ signature_name TEXT,
+ key_id TEXT,
+ signature BLOB,
+ CONSTRAINT uniqueness UNIQUE (event_id, key_id)
+);
+
+CREATE INDEX IF NOT EXISTS event_signatures_id ON event_signatures (
+ event_id
+);
+
+
+CREATE TABLE IF NOT EXISTS event_edge_hashes(
+ event_id TEXT,
+ prev_event_id TEXT,
+ algorithm TEXT,
+ hash BLOB,
+ CONSTRAINT uniqueness UNIQUE (
+ event_id, prev_event_id, algorithm
+ )
+);
+
+CREATE INDEX IF NOT EXISTS event_edge_hashes_id ON event_edge_hashes(
+ event_id
+);
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
index 3aa83f5c8c..8ba732a23b 100644
--- a/synapse/storage/schema/im.sql
+++ b/synapse/storage/schema/im.sql
@@ -23,6 +23,7 @@ CREATE TABLE IF NOT EXISTS events(
unrecognized_keys TEXT,
processed BOOL NOT NULL,
outlier BOOL NOT NULL,
+ depth INTEGER DEFAULT 0 NOT NULL,
CONSTRAINT ev_uniq UNIQUE (event_id)
);
@@ -84,80 +85,24 @@ CREATE TABLE IF NOT EXISTS topics(
topic TEXT NOT NULL
);
+CREATE INDEX IF NOT EXISTS topics_event_id ON topics(event_id);
+CREATE INDEX IF NOT EXISTS topics_room_id ON topics(room_id);
+
CREATE TABLE IF NOT EXISTS room_names(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
name TEXT NOT NULL
);
+CREATE INDEX IF NOT EXISTS room_names_event_id ON room_names(event_id);
+CREATE INDEX IF NOT EXISTS room_names_room_id ON room_names(room_id);
+
CREATE TABLE IF NOT EXISTS rooms(
room_id TEXT PRIMARY KEY NOT NULL,
is_public INTEGER,
creator TEXT
);
-CREATE TABLE IF NOT EXISTS room_join_rules(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- join_rule TEXT NOT NULL
-);
-CREATE INDEX IF NOT EXISTS room_join_rules_event_id ON room_join_rules(event_id);
-CREATE INDEX IF NOT EXISTS room_join_rules_room_id ON room_join_rules(room_id);
-
-
-CREATE TABLE IF NOT EXISTS room_power_levels(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- user_id TEXT NOT NULL,
- level INTEGER NOT NULL
-);
-CREATE INDEX IF NOT EXISTS room_power_levels_event_id ON room_power_levels(event_id);
-CREATE INDEX IF NOT EXISTS room_power_levels_room_id ON room_power_levels(room_id);
-CREATE INDEX IF NOT EXISTS room_power_levels_room_user ON room_power_levels(room_id, user_id);
-
-
-CREATE TABLE IF NOT EXISTS room_default_levels(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- level INTEGER NOT NULL
-);
-
-CREATE INDEX IF NOT EXISTS room_default_levels_event_id ON room_default_levels(event_id);
-CREATE INDEX IF NOT EXISTS room_default_levels_room_id ON room_default_levels(room_id);
-
-
-CREATE TABLE IF NOT EXISTS room_add_state_levels(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- level INTEGER NOT NULL
-);
-
-CREATE INDEX IF NOT EXISTS room_add_state_levels_event_id ON room_add_state_levels(event_id);
-CREATE INDEX IF NOT EXISTS room_add_state_levels_room_id ON room_add_state_levels(room_id);
-
-
-CREATE TABLE IF NOT EXISTS room_send_event_levels(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- level INTEGER NOT NULL
-);
-
-CREATE INDEX IF NOT EXISTS room_send_event_levels_event_id ON room_send_event_levels(event_id);
-CREATE INDEX IF NOT EXISTS room_send_event_levels_room_id ON room_send_event_levels(room_id);
-
-
-CREATE TABLE IF NOT EXISTS room_ops_levels(
- event_id TEXT NOT NULL,
- room_id TEXT NOT NULL,
- ban_level INTEGER,
- kick_level INTEGER,
- redact_level INTEGER
-);
-
-CREATE INDEX IF NOT EXISTS room_ops_levels_event_id ON room_ops_levels(event_id);
-CREATE INDEX IF NOT EXISTS room_ops_levels_room_id ON room_ops_levels(room_id);
-
-
CREATE TABLE IF NOT EXISTS room_hosts(
room_id TEXT NOT NULL,
host TEXT NOT NULL,
diff --git a/synapse/storage/schema/pdu.sql b/synapse/storage/schema/pdu.sql
deleted file mode 100644
index 16e111a56c..0000000000
--- a/synapse/storage/schema/pdu.sql
+++ /dev/null
@@ -1,106 +0,0 @@
-/* Copyright 2014 OpenMarket Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
--- Stores pdus and their content
-CREATE TABLE IF NOT EXISTS pdus(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- pdu_type TEXT,
- ts INTEGER,
- depth INTEGER DEFAULT 0 NOT NULL,
- is_state BOOL,
- content_json TEXT,
- unrecognized_keys TEXT,
- outlier BOOL NOT NULL,
- have_processed BOOL,
- CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
-);
-
--- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
-CREATE TABLE IF NOT EXISTS state_pdus(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- pdu_type TEXT,
- state_key TEXT,
- power_level TEXT,
- prev_state_id TEXT,
- prev_state_origin TEXT,
- CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
- CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
-);
-
-CREATE TABLE IF NOT EXISTS current_state(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- pdu_type TEXT,
- state_key TEXT,
- CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
- CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
-);
-
--- Stores where each pdu we want to send should be sent and the delivery status.
-create TABLE IF NOT EXISTS pdu_destinations(
- pdu_id TEXT,
- origin TEXT,
- destination TEXT,
- delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
- pdu_id TEXT,
- origin TEXT,
- context TEXT,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
-);
-
-CREATE TABLE IF NOT EXISTS pdu_edges(
- pdu_id TEXT,
- origin TEXT,
- prev_pdu_id TEXT,
- prev_origin TEXT,
- context TEXT,
- CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
-);
-
-CREATE TABLE IF NOT EXISTS context_depth(
- context TEXT,
- min_depth INTEGER,
- CONSTRAINT uniqueness UNIQUE (context)
-);
-
-CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);
-
-
-CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
--- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);
-
-CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
-CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);
-
-CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);
diff --git a/synapse/storage/schema/state.sql b/synapse/storage/schema/state.sql
new file mode 100644
index 0000000000..44f7aafb27
--- /dev/null
+++ b/synapse/storage/schema/state.sql
@@ -0,0 +1,46 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS state_groups(
+ id INTEGER PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS state_groups_state(
+ state_group INTEGER NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS event_to_state_groups(
+ event_id TEXT NOT NULL,
+ state_group INTEGER NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS state_groups_id ON state_groups(id);
+
+CREATE INDEX IF NOT EXISTS state_groups_state_id ON state_groups_state(
+ state_group
+);
+CREATE INDEX IF NOT EXISTS state_groups_state_tuple ON state_groups_state(
+ room_id, type, state_key
+);
+
+CREATE INDEX IF NOT EXISTS event_to_state_groups_id ON event_to_state_groups(
+ event_id
+);
\ No newline at end of file
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
new file mode 100644
index 0000000000..d90e08fff1
--- /dev/null
+++ b/synapse/storage/signatures.py
@@ -0,0 +1,183 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from _base import SQLBaseStore
+
+
+class SignatureStore(SQLBaseStore):
+ """Persistence for event signatures and hashes"""
+
+ def _get_event_content_hashes_txn(self, txn, event_id):
+ """Get all the hashes for a given Event.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+ A dict of algorithm -> hash.
+ """
+ query = (
+ "SELECT algorithm, hash"
+ " FROM event_content_hashes"
+ " WHERE event_id = ?"
+ )
+ txn.execute(query, (event_id, ))
+ return dict(txn.fetchall())
+
+ def _store_event_content_hash_txn(self, txn, event_id, algorithm,
+ hash_bytes):
+ """Store a hash for a Event
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ algorithm (str): Hashing algorithm.
+ hash_bytes (bytes): Hash function output bytes.
+ """
+ self._simple_insert_txn(
+ txn,
+ "event_content_hashes",
+ {
+ "event_id": event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
+
+ def get_event_reference_hashes(self, event_ids):
+ def f(txn):
+ return [
+ self._get_event_reference_hashes_txn(txn, ev)
+ for ev in event_ids
+ ]
+
+ return self.runInteraction(
+ "get_event_reference_hashes",
+ f
+ )
+
+ def _get_event_reference_hashes_txn(self, txn, event_id):
+ """Get all the hashes for a given PDU.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+ A dict of algorithm -> hash.
+ """
+ query = (
+ "SELECT algorithm, hash"
+ " FROM event_reference_hashes"
+ " WHERE event_id = ?"
+ )
+ txn.execute(query, (event_id, ))
+ return dict(txn.fetchall())
+
+ def _store_event_reference_hash_txn(self, txn, event_id, algorithm,
+ hash_bytes):
+ """Store a hash for a PDU
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ algorithm (str): Hashing algorithm.
+ hash_bytes (bytes): Hash function output bytes.
+ """
+ self._simple_insert_txn(
+ txn,
+ "event_reference_hashes",
+ {
+ "event_id": event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
+
+ def _get_event_signatures_txn(self, txn, event_id):
+ """Get all the signatures for a given PDU.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+ A dict of sig name -> dict(key_id -> signature_bytes)
+ """
+ query = (
+ "SELECT signature_name, key_id, signature"
+ " FROM event_signatures"
+ " WHERE event_id = ? "
+ )
+ txn.execute(query, (event_id, ))
+ rows = txn.fetchall()
+
+ res = {}
+
+ for name, key, sig in rows:
+ res.setdefault(name, {})[key] = sig
+
+ return res
+
+ def _store_event_signature_txn(self, txn, event_id, signature_name, key_id,
+ signature_bytes):
+ """Store a signature from the origin server for a PDU.
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ origin (str): origin of the Event.
+ key_id (str): Id for the signing key.
+ signature (bytes): The signature.
+ """
+ self._simple_insert_txn(
+ txn,
+ "event_signatures",
+ {
+ "event_id": event_id,
+ "signature_name": signature_name,
+ "key_id": key_id,
+ "signature": buffer(signature_bytes),
+ },
+ or_ignore=True,
+ )
+
+ def _get_prev_event_hashes_txn(self, txn, event_id):
+ """Get all the hashes for previous PDUs of a PDU
+ Args:
+ txn (cursor):
+ event_id (str): Id for the Event.
+ Returns:
+ dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
+ """
+ query = (
+ "SELECT prev_event_id, algorithm, hash"
+ " FROM event_edge_hashes"
+ " WHERE event_id = ?"
+ )
+ txn.execute(query, (event_id, ))
+ results = {}
+ for prev_event_id, algorithm, hash_bytes in txn.fetchall():
+ hashes = results.setdefault(prev_event_id, {})
+ hashes[algorithm] = hash_bytes
+ return results
+
+ def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
+ algorithm, hash_bytes):
+ self._simple_insert_txn(
+ txn,
+ "event_edge_hashes",
+ {
+ "event_id": event_id,
+ "prev_event_id": prev_event_id,
+ "algorithm": algorithm,
+ "hash": buffer(hash_bytes),
+ },
+ or_ignore=True,
+ )
\ No newline at end of file
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
new file mode 100644
index 0000000000..55ea567793
--- /dev/null
+++ b/synapse/storage/state.py
@@ -0,0 +1,127 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+
+
+class StateStore(SQLBaseStore):
+ """ Keeps track of the state at a given event.
+
+ This is done by the concept of `state groups`. Every event is a assigned
+ a state group (identified by an arbitrary string), which references a
+ collection of state events. The current state of an event is then the
+ collection of state events referenced by the event's state group.
+
+ Hence, every change in the current state causes a new state group to be
+ generated. However, if no change happens (e.g., if we get a message event
+ with only one parent it inherits the state group from its parent.)
+
+ There are three tables:
+ * `state_groups`: Stores group name, first event with in the group and
+ room id.
+ * `event_to_state_groups`: Maps events to state groups.
+ * `state_groups_state`: Maps state group to state events.
+ """
+
+ def get_state_groups(self, event_ids):
+ """ Get the state groups for the given list of event_ids
+
+ The return value is a dict mapping group names to lists of events.
+ """
+
+ def f(txn):
+ groups = set()
+ for event_id in event_ids:
+ group = self._simple_select_one_onecol_txn(
+ txn,
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ retcol="state_group",
+ allow_none=True,
+ )
+ if group:
+ groups.add(group)
+
+ res = {}
+ for group in groups:
+ state_ids = self._simple_select_onecol_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": group},
+ retcol="event_id",
+ )
+ state = []
+ for state_id in state_ids:
+ s = self._get_events_txn(
+ txn,
+ [state_id],
+ )
+ if s:
+ state.extend(s)
+
+ res[group] = state
+
+ return res
+
+ return self.runInteraction(
+ "get_state_groups",
+ f,
+ )
+
+ def store_state_groups(self, event):
+ return self.runInteraction(
+ "store_state_groups",
+ self._store_state_groups_txn, event
+ )
+
+ def _store_state_groups_txn(self, txn, event):
+ if not event.state_events:
+ return
+
+ state_group = event.state_group
+ if not state_group:
+ state_group = self._simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ },
+ or_ignore=True,
+ )
+
+ for state in event.state_events.values():
+ self._simple_insert_txn(
+ txn,
+ table="state_groups_state",
+ values={
+ "state_group": state_group,
+ "room_id": state.room_id,
+ "type": state.type,
+ "state_key": state.state_key,
+ "event_id": state.event_id,
+ },
+ or_ignore=True,
+ )
+
+ self._simple_insert_txn(
+ txn,
+ table="event_to_state_groups",
+ values={
+ "state_group": state_group,
+ "event_id": event.event_id,
+ },
+ or_replace=True,
+ )
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d61f909939..475e7f20a1 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -177,10 +177,9 @@ class StreamStore(SQLBaseStore):
sql = (
"SELECT *, (%(redacted)s) AS redacted FROM events AS e WHERE "
- "((room_id IN (%(current)s)) OR "
+ "(e.outlier = 0 AND (room_id IN (%(current)s)) OR "
"(event_id IN (%(invites)s))) "
"AND e.stream_ordering > ? AND e.stream_ordering <= ? "
- "AND e.outlier = 0 "
"ORDER BY stream_ordering ASC LIMIT %(limit)d "
) % {
"redacted": del_sql,
@@ -309,7 +308,10 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
def get_room_events_max_id(self):
- return self.runInteraction(self._get_room_events_max_id_txn)
+ return self.runInteraction(
+ "get_room_events_max_id",
+ self._get_room_events_max_id_txn
+ )
def _get_room_events_max_id_txn(self, txn):
txn.execute(
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 2ba8e30efe..00d0f48082 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -14,7 +14,6 @@
# limitations under the License.
from ._base import SQLBaseStore, Table
-from .pdu import PdusTable
from collections import namedtuple
@@ -42,6 +41,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
+ "get_received_txn_response",
self._get_received_txn_response, transaction_id, origin
)
@@ -73,6 +73,7 @@ class TransactionStore(SQLBaseStore):
"""
return self.runInteraction(
+ "set_received_txn_response",
self._set_received_txn_response,
transaction_id, origin, code, response_dict
)
@@ -88,7 +89,7 @@ class TransactionStore(SQLBaseStore):
txn.execute(query, (code, response_json, transaction_id, origin))
def prep_send_transaction(self, transaction_id, destination,
- origin_server_ts, pdu_list):
+ origin_server_ts):
"""Persists an outgoing transaction and calculates the values for the
previous transaction id list.
@@ -99,19 +100,19 @@ class TransactionStore(SQLBaseStore):
transaction_id (str)
destination (str)
origin_server_ts (int)
- pdu_list (list)
Returns:
list: A list of previous transaction ids.
"""
return self.runInteraction(
+ "prep_send_transaction",
self._prep_send_transaction,
- transaction_id, destination, origin_server_ts, pdu_list
+ transaction_id, destination, origin_server_ts
)
def _prep_send_transaction(self, txn, transaction_id, destination,
- origin_server_ts, pdu_list):
+ origin_server_ts):
# First we find out what the prev_txs should be.
# Since we know that we are only sending one transaction at a time,
@@ -139,15 +140,15 @@ class TransactionStore(SQLBaseStore):
# Update the tx id -> pdu id mapping
- values = [
- (transaction_id, destination, pdu[0], pdu[1])
- for pdu in pdu_list
- ]
-
- logger.debug("Inserting: %s", repr(values))
-
- query = TransactionsToPduTable.insert_statement()
- txn.executemany(query, values)
+ # values = [
+ # (transaction_id, destination, pdu[0], pdu[1])
+ # for pdu in pdu_list
+ # ]
+ #
+ # logger.debug("Inserting: %s", repr(values))
+ #
+ # query = TransactionsToPduTable.insert_statement()
+ # txn.executemany(query, values)
return prev_txns
@@ -161,6 +162,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
return self.runInteraction(
+ "delivered_txn",
self._delivered_txn,
transaction_id, destination, code, response_dict
)
@@ -186,6 +188,7 @@ class TransactionStore(SQLBaseStore):
list: A list of `ReceivedTransactionsTable.EntryType`
"""
return self.runInteraction(
+ "get_transactions_after",
self._get_transactions_after, transaction_id, destination
)
@@ -202,49 +205,6 @@ class TransactionStore(SQLBaseStore):
return ReceivedTransactionsTable.decode_results(txn.fetchall())
- def get_pdus_after_transaction(self, transaction_id, destination):
- """For a given local transaction_id that we sent to a given destination
- home server, return a list of PDUs that were sent to that destination
- after it.
-
- Args:
- txn
- transaction_id (str)
- destination (str)
-
- Returns
- list: A list of PduTuple
- """
- return self.runInteraction(
- self._get_pdus_after_transaction,
- transaction_id, destination
- )
-
- def _get_pdus_after_transaction(self, txn, transaction_id, destination):
-
- # Query that first get's all transaction_ids with an id greater than
- # the one given from the `sent_transactions` table. Then JOIN on this
- # from the `tx->pdu` table to get a list of (pdu_id, origin) that
- # specify the pdus that were sent in those transactions.
- query = (
- "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
- "INNER JOIN %(sent_tx)s as st "
- "ON tp.transaction_id = st.transaction_id "
- "AND tp.destination = st.destination "
- "WHERE st.id > ("
- "SELECT id FROM %(sent_tx)s "
- "WHERE transaction_id = ? AND destination = ?"
- ) % {
- "tx_pdu": TransactionsToPduTable.table_name,
- "sent_tx": SentTransactions.table_name,
- }
-
- txn.execute(query, (transaction_id, destination))
-
- pdus = PdusTable.decode_results(txn.fetchall())
-
- return self._get_pdu_tuples(txn, pdus)
-
class ReceivedTransactionsTable(Table):
table_name = "received_transactions"
|