diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
new file mode 100644
index 0000000000..ec93f9f8a7
--- /dev/null
+++ b/synapse/storage/__init__.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.api.events.room import (
+ RoomMemberEvent, MessageEvent, RoomTopicEvent, FeedbackEvent,
+ RoomConfigEvent
+)
+
+from .directory import DirectoryStore
+from .feedback import FeedbackStore
+from .message import MessageStore
+from .presence import PresenceStore
+from .profile import ProfileStore
+from .registration import RegistrationStore
+from .room import RoomStore
+from .roommember import RoomMemberStore
+from .roomdata import RoomDataStore
+from .stream import StreamStore
+from .pdu import StatePduStore, PduStore
+from .transactions import TransactionStore
+
+import json
+import os
+
+
+class DataStore(RoomDataStore, RoomMemberStore, MessageStore, RoomStore,
+ RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
+ PresenceStore, PduStore, StatePduStore, TransactionStore,
+ DirectoryStore):
+
+ def __init__(self, hs):
+ super(DataStore, self).__init__(hs)
+ self.event_factory = hs.get_event_factory()
+ self.hs = hs
+
+ def persist_event(self, event):
+ if event.type == MessageEvent.TYPE:
+ return self.store_message(
+ user_id=event.user_id,
+ room_id=event.room_id,
+ msg_id=event.msg_id,
+ content=json.dumps(event.content)
+ )
+ elif event.type == RoomMemberEvent.TYPE:
+ return self.store_room_member(
+ user_id=event.target_user_id,
+ sender=event.user_id,
+ room_id=event.room_id,
+ content=event.content,
+ membership=event.content["membership"]
+ )
+ elif event.type == FeedbackEvent.TYPE:
+ return self.store_feedback(
+ room_id=event.room_id,
+ msg_id=event.msg_id,
+ msg_sender_id=event.msg_sender_id,
+ fb_sender_id=event.user_id,
+ fb_type=event.feedback_type,
+ content=json.dumps(event.content)
+ )
+ elif event.type == RoomTopicEvent.TYPE:
+ return self.store_room_data(
+ room_id=event.room_id,
+ etype=event.type,
+ state_key=event.state_key,
+ content=json.dumps(event.content)
+ )
+ elif event.type == RoomConfigEvent.TYPE:
+ if "visibility" in event.content:
+ visibility = event.content["visibility"]
+ return self.store_room_config(
+ room_id=event.room_id,
+ visibility=visibility
+ )
+
+ else:
+ raise NotImplementedError(
+ "Don't know how to persist type=%s" % event.type
+ )
+
+
+def schema_path(schema):
+ """ Get a filesystem path for the named database schema
+
+ Args:
+ schema: Name of the database schema.
+ Returns:
+ A filesystem path pointing at a ".sql" file.
+
+ """
+ dir_path = os.path.dirname(__file__)
+ schemaPath = os.path.join(dir_path, "schema", schema + ".sql")
+ return schemaPath
+
+
+def read_schema(schema):
+ """ Read the named database schema.
+
+ Args:
+ schema: Name of the datbase schema.
+ Returns:
+ A string containing the database schema.
+ """
+ with open(schema_path(schema)) as schema_file:
+ return schema_file.read()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
new file mode 100644
index 0000000000..4d98a6fd0d
--- /dev/null
+++ b/synapse/storage/_base.py
@@ -0,0 +1,405 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+
+import collections
+
+logger = logging.getLogger(__name__)
+
+
+class SQLBaseStore(object):
+
+ def __init__(self, hs):
+ self._db_pool = hs.get_db_pool()
+
+ def cursor_to_dict(self, cursor):
+ """Converts a SQL cursor into an list of dicts.
+
+ Args:
+ cursor : The DBAPI cursor which has executed a query.
+ Returns:
+ A list of dicts where the key is the column header.
+ """
+ col_headers = list(column[0] for column in cursor.description)
+ results = list(
+ dict(zip(col_headers, row)) for row in cursor.fetchall()
+ )
+ return results
+
+ def _execute(self, decoder, query, *args):
+ """Runs a single query for a result set.
+
+ Args:
+ decoder - The function which can resolve the cursor results to
+ something meaningful.
+ query - The query string to execute
+ *args - Query args.
+ Returns:
+ The result of decoder(results)
+ """
+ logger.debug(
+ "[SQL] %s Args=%s Func=%s", query, args, decoder.__name__
+ )
+
+ def interaction(txn):
+ cursor = txn.execute(query, args)
+ return decoder(cursor)
+ return self._db_pool.runInteraction(interaction)
+
+ # "Simple" SQL API methods that operate on a single table with no JOINs,
+ # no complex WHERE clauses, just a dict of values for columns.
+
+ def _simple_insert(self, table, values):
+ """Executes an INSERT query on the named table.
+
+ Args:
+ table : string giving the table name
+ values : dict of new column names and values for them
+ """
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+ table,
+ ", ".join(k for k in values),
+ ", ".join("?" for k in values)
+ )
+
+ def func(txn):
+ txn.execute(sql, values.values())
+ return txn.lastrowid
+ return self._db_pool.runInteraction(func)
+
+ def _simple_select_one(self, table, keyvalues, retcols,
+ allow_none=False):
+ """Executes a SELECT query on the named table, which is expected to
+ return a single row, returning a single column from it.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ retcols : list of strings giving the names of the columns to return
+
+ allow_none : If true, return None instead of failing if the SELECT
+ statement returns no rows
+ """
+ return self._simple_selectupdate_one(
+ table, keyvalues, retcols=retcols, allow_none=allow_none
+ )
+
+ @defer.inlineCallbacks
+ def _simple_select_one_onecol(self, table, keyvalues, retcol,
+ allow_none=False):
+ """Executes a SELECT query on the named table, which is expected to
+ return a single row, returning a single column from it."
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ retcol : string giving the name of the column to return
+ """
+ ret = yield self._simple_select_one(
+ table=table,
+ keyvalues=keyvalues,
+ retcols=[retcol],
+ allow_none=allow_none
+ )
+
+ if ret:
+ defer.returnValue(ret[retcol])
+ else:
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def _simple_select_onecol(self, table, keyvalues, retcol):
+ """Executes a SELECT query on the named table, which returns a list
+ comprising of the values of the named column from the selected rows.
+
+ Args:
+ table (str): table name
+ keyvalues (dict): column names and values to select the rows with
+ retcol (str): column whos value we wish to retrieve.
+
+ Returns:
+ Deferred: Results in a list
+ """
+ sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % {
+ "retcol": retcol,
+ "table": table,
+ "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()),
+ }
+
+ def func(txn):
+ txn.execute(sql, keyvalues.values())
+ return txn.fetchall()
+
+ res = yield self._db_pool.runInteraction(func)
+
+ defer.returnValue([r[0] for r in res])
+
+ def _simple_select_list(self, table, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k) for k in keyvalues)
+ )
+
+ def func(txn):
+ txn.execute(sql, keyvalues.values())
+ return self.cursor_to_dict(txn)
+
+ return self._db_pool.runInteraction(func)
+
+ def _simple_update_one(self, table, keyvalues, updatevalues,
+ retcols=None):
+ """Executes an UPDATE query on the named table, setting new values for
+ columns in a row matching the key values.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ updatevalues : dict giving column names and values to update
+ retcols : optional list of column names to return
+
+ If present, retcols gives a list of column names on which to perform
+ a SELECT statement *before* performing the UPDATE statement. The values
+ of these will be returned in a dict.
+
+ These are performed within the same transaction, allowing an atomic
+ get-and-set. This can be used to implement compare-and-set by putting
+ the update column in the 'keyvalues' dict as well.
+ """
+ return self._simple_selectupdate_one(table, keyvalues, updatevalues,
+ retcols=retcols)
+
+ def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
+ retcols=None, allow_none=False):
+ """ Combined SELECT then UPDATE."""
+ if retcols:
+ select_sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k) for k in keyvalues)
+ )
+
+ if updatevalues:
+ update_sql = "UPDATE %s SET %s WHERE %s" % (
+ table,
+ ", ".join("%s = ?" % (k) for k in updatevalues),
+ " AND ".join("%s = ?" % (k) for k in keyvalues)
+ )
+
+ def func(txn):
+ ret = None
+ if retcols:
+ txn.execute(select_sql, keyvalues.values())
+
+ row = txn.fetchone()
+ if not row:
+ if allow_none:
+ return None
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched")
+
+ ret = dict(zip(retcols, row))
+
+ if updatevalues:
+ txn.execute(
+ update_sql,
+ updatevalues.values() + keyvalues.values()
+ )
+
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched")
+
+ return ret
+ return self._db_pool.runInteraction(func)
+
+ def _simple_delete_one(self, table, keyvalues):
+ """Executes a DELETE query on the named table, expecting to delete a
+ single row.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k) for k in keyvalues)
+ )
+
+ def func(txn):
+ txn.execute(sql, keyvalues.values())
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found")
+ if txn.rowcount > 1:
+ raise StoreError(500, "more than one row matched")
+ return self._db_pool.runInteraction(func)
+
+ def _simple_max_id(self, table):
+ """Executes a SELECT query on the named table, expecting to return the
+ max value for the column "id".
+
+ Args:
+ table : string giving the table name
+ """
+ sql = "SELECT MAX(id) AS id FROM %s" % table
+
+ def func(txn):
+ txn.execute(sql)
+ max_id = self.cursor_to_dict(txn)[0]["id"]
+ if max_id is None:
+ return 0
+ return max_id
+
+ return self._db_pool.runInteraction(func)
+
+
+class Table(object):
+ """ A base class used to store information about a particular table.
+ """
+
+ table_name = None
+ """ str: The name of the table """
+
+ fields = None
+ """ list: The field names """
+
+ EntryType = None
+ """ Type: A tuple type used to decode the results """
+
+ _select_where_clause = "SELECT %s FROM %s WHERE %s"
+ _select_clause = "SELECT %s FROM %s"
+ _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"
+
+ @classmethod
+ def select_statement(cls, where_clause=None):
+ """
+ Args:
+ where_clause (str): The WHERE clause to use.
+
+ Returns:
+ str: An SQL statement to select rows from the table with the given
+ WHERE clause.
+ """
+ if where_clause:
+ return cls._select_where_clause % (
+ ", ".join(cls.fields),
+ cls.table_name,
+ where_clause
+ )
+ else:
+ return cls._select_clause % (
+ ", ".join(cls.fields),
+ cls.table_name,
+ )
+
+ @classmethod
+ def insert_statement(cls):
+ return cls._insert_clause % (
+ cls.table_name,
+ ", ".join(cls.fields),
+ ", ".join(["?"] * len(cls.fields)),
+ )
+
+ @classmethod
+ def decode_single_result(cls, results):
+ """ Given an iterable of tuples, return a single instance of
+ `EntryType` or None if the iterable is empty
+ Args:
+ results (list): The results list to convert to `EntryType`
+ Returns:
+ EntryType: An instance of `EntryType`
+ """
+ results = list(results)
+ if results:
+ return cls.EntryType(*results[0])
+ else:
+ return None
+
+ @classmethod
+ def decode_results(cls, results):
+ """ Given an iterable of tuples, return a list of `EntryType`
+ Args:
+ results (list): The results list to convert to `EntryType`
+
+ Returns:
+ list: A list of `EntryType`
+ """
+ return [cls.EntryType(*row) for row in results]
+
+ @classmethod
+ def get_fields_string(cls, prefix=None):
+ if prefix:
+ to_join = ("%s.%s" % (prefix, f) for f in cls.fields)
+ else:
+ to_join = cls.fields
+
+ return ", ".join(to_join)
+
+
+class JoinHelper(object):
+ """ Used to help do joins on tables by looking at the tables' fields and
+ creating a list of unique fields to use with SELECTs and a namedtuple
+ to dump the results into.
+
+ Attributes:
+ taples (list): List of `Table` classes
+ EntryType (type)
+ """
+
+ def __init__(self, *tables):
+ self.tables = tables
+
+ res = []
+ for table in self.tables:
+ res += [f for f in table.fields if f not in res]
+
+ self.EntryType = collections.namedtuple("JoinHelperEntry", res)
+
+ def get_fields(self, **prefixes):
+ """Get a string representing a list of fields for use in SELECT
+ statements with the given prefixes applied to each.
+
+ For example::
+
+ JoinHelper(PdusTable, StateTable).get_fields(
+ PdusTable="pdus",
+ StateTable="state"
+ )
+ """
+ res = []
+ for field in self.EntryType._fields:
+ for table in self.tables:
+ if field in table.fields:
+ res.append("%s.%s" % (prefixes[table.__name__], field))
+ break
+
+ return ", ".join(res)
+
+ def decode_results(self, rows):
+ return [self.EntryType(*row) for row in rows]
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
new file mode 100644
index 0000000000..71fa9d9c9c
--- /dev/null
+++ b/synapse/storage/directory.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+from twisted.internet import defer
+
+from collections import namedtuple
+
+
+RoomAliasMapping = namedtuple(
+ "RoomAliasMapping",
+ ("room_id", "room_alias", "servers",)
+)
+
+
+class DirectoryStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def get_association_from_room_alias(self, room_alias):
+ """ Get's the room_id and server list for a given room_alias
+
+ Args:
+ room_alias (RoomAlias)
+
+ Returns:
+ Deferred: results in namedtuple with keys "room_id" and
+ "servers" or None if no association can be found
+ """
+ room_id = yield self._simple_select_one_onecol(
+ "room_aliases",
+ {"room_alias": room_alias.to_string()},
+ "room_id",
+ allow_none=True,
+ )
+
+ if not room_id:
+ defer.returnValue(None)
+ return
+
+ servers = yield self._simple_select_onecol(
+ "room_alias_servers",
+ {"room_alias": room_alias.to_string()},
+ "server",
+ )
+
+ if not servers:
+ defer.returnValue(None)
+ return
+
+ defer.returnValue(
+ RoomAliasMapping(room_id, room_alias.to_string(), servers)
+ )
+
+ @defer.inlineCallbacks
+ def create_room_alias_association(self, room_alias, room_id, servers):
+ """ Creates an associatin between a room alias and room_id/servers
+
+ Args:
+ room_alias (RoomAlias)
+ room_id (str)
+ servers (list)
+
+ Returns:
+ Deferred
+ """
+ yield self._simple_insert(
+ "room_aliases",
+ {
+ "room_alias": room_alias.to_string(),
+ "room_id": room_id,
+ },
+ )
+
+ for server in servers:
+ # TODO(erikj): Fix this to bulk insert
+ yield self._simple_insert(
+ "room_alias_servers",
+ {
+ "room_alias": room_alias.to_string(),
+ "server": server,
+ }
+ )
diff --git a/synapse/storage/feedback.py b/synapse/storage/feedback.py
new file mode 100644
index 0000000000..2b421e3342
--- /dev/null
+++ b/synapse/storage/feedback.py
@@ -0,0 +1,74 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore, Table
+from synapse.api.events.room import FeedbackEvent
+
+import collections
+import json
+
+
+class FeedbackStore(SQLBaseStore):
+
+ def store_feedback(self, room_id, msg_id, msg_sender_id,
+ fb_sender_id, fb_type, content):
+ return self._simple_insert(FeedbackTable.table_name, dict(
+ room_id=room_id,
+ msg_id=msg_id,
+ msg_sender_id=msg_sender_id,
+ fb_sender_id=fb_sender_id,
+ fb_type=fb_type,
+ content=content,
+ ))
+
+ def get_feedback(self, room_id=None, msg_id=None, msg_sender_id=None,
+ fb_sender_id=None, fb_type=None):
+ query = FeedbackTable.select_statement(
+ "msg_sender_id = ? AND room_id = ? AND msg_id = ? " +
+ "AND fb_sender_id = ? AND feedback_type = ? " +
+ "ORDER BY id DESC LIMIT 1")
+ return self._execute(
+ FeedbackTable.decode_single_result,
+ query, msg_sender_id, room_id, msg_id, fb_sender_id, fb_type,
+ )
+
+ def get_max_feedback_id(self):
+ return self._simple_max_id(FeedbackTable.table_name)
+
+
+class FeedbackTable(Table):
+ table_name = "feedback"
+
+ fields = [
+ "id",
+ "content",
+ "feedback_type",
+ "fb_sender_id",
+ "msg_id",
+ "room_id",
+ "msg_sender_id"
+ ]
+
+ class EntryType(collections.namedtuple("FeedbackEntry", fields)):
+
+ def as_event(self, event_factory):
+ return event_factory.create_event(
+ etype=FeedbackEvent.TYPE,
+ room_id=self.room_id,
+ msg_id=self.msg_id,
+ msg_sender_id=self.msg_sender_id,
+ user_id=self.fb_sender_id,
+ feedback_type=self.feedback_type,
+ content=json.loads(self.content),
+ )
diff --git a/synapse/storage/message.py b/synapse/storage/message.py
new file mode 100644
index 0000000000..4822fa709d
--- /dev/null
+++ b/synapse/storage/message.py
@@ -0,0 +1,80 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore, Table
+from synapse.api.events.room import MessageEvent
+
+import collections
+import json
+
+
+class MessageStore(SQLBaseStore):
+
+ def get_message(self, user_id, room_id, msg_id):
+ """Get a message from the store.
+
+ Args:
+ user_id (str): The ID of the user who sent the message.
+ room_id (str): The room the message was sent in.
+ msg_id (str): The unique ID for this user/room combo.
+ """
+ query = MessagesTable.select_statement(
+ "user_id = ? AND room_id = ? AND msg_id = ? " +
+ "ORDER BY id DESC LIMIT 1")
+ return self._execute(
+ MessagesTable.decode_single_result,
+ query, user_id, room_id, msg_id,
+ )
+
+ def store_message(self, user_id, room_id, msg_id, content):
+ """Store a message in the store.
+
+ Args:
+ user_id (str): The ID of the user who sent the message.
+ room_id (str): The room the message was sent in.
+ msg_id (str): The unique ID for this user/room combo.
+ content (str): The content of the message (JSON)
+ """
+ return self._simple_insert(MessagesTable.table_name, dict(
+ user_id=user_id,
+ room_id=room_id,
+ msg_id=msg_id,
+ content=content,
+ ))
+
+ def get_max_message_id(self):
+ return self._simple_max_id(MessagesTable.table_name)
+
+
+class MessagesTable(Table):
+ table_name = "messages"
+
+ fields = [
+ "id",
+ "user_id",
+ "room_id",
+ "msg_id",
+ "content"
+ ]
+
+ class EntryType(collections.namedtuple("MessageEntry", fields)):
+
+ def as_event(self, event_factory):
+ return event_factory.create_event(
+ etype=MessageEvent.TYPE,
+ room_id=self.room_id,
+ user_id=self.user_id,
+ msg_id=self.msg_id,
+ content=json.loads(self.content),
+ )
diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py
new file mode 100644
index 0000000000..a1cdde0a3b
--- /dev/null
+++ b/synapse/storage/pdu.py
@@ -0,0 +1,993 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore, Table, JoinHelper
+
+from synapse.util.logutils import log_function
+
+from collections import namedtuple
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class PduStore(SQLBaseStore):
+ """A collection of queries for handling PDUs.
+ """
+
+ def get_pdu(self, pdu_id, origin):
+ """Given a pdu_id and origin, get a PDU.
+
+ Args:
+ txn
+ pdu_id (str)
+ origin (str)
+
+ Returns:
+ PduTuple: If the pdu does not exist in the database, returns None
+ """
+
+ return self._db_pool.runInteraction(
+ self._get_pdu_tuple, pdu_id, origin
+ )
+
+ def _get_pdu_tuple(self, txn, pdu_id, origin):
+ res = self._get_pdu_tuples(txn, [(pdu_id, origin)])
+ return res[0] if res else None
+
+ def _get_pdu_tuples(self, txn, pdu_id_tuples):
+ results = []
+ for pdu_id, origin in pdu_id_tuples:
+ txn.execute(
+ PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"),
+ (pdu_id, origin)
+ )
+
+ edges = [
+ (r.prev_pdu_id, r.prev_origin)
+ for r in PduEdgesTable.decode_results(txn.fetchall())
+ ]
+
+ query = (
+ "SELECT %(fields)s FROM %(pdus)s as p "
+ "LEFT JOIN %(state)s as s "
+ "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
+ "WHERE p.pdu_id = ? AND p.origin = ? "
+ ) % {
+ "fields": _pdu_state_joiner.get_fields(
+ PdusTable="p", StatePdusTable="s"),
+ "pdus": PdusTable.table_name,
+ "state": StatePdusTable.table_name,
+ }
+
+ txn.execute(query, (pdu_id, origin))
+
+ row = txn.fetchone()
+ if row:
+ results.append(PduTuple(PduEntry(*row), edges))
+
+ return results
+
+ def get_current_state_for_context(self, context):
+ """Get a list of PDUs that represent the current state for a given
+ context
+
+ Args:
+ context (str)
+
+ Returns:
+ list: A list of PduTuples
+ """
+
+ return self._db_pool.runInteraction(
+ self._get_current_state_for_context,
+ context
+ )
+
+ def _get_current_state_for_context(self, txn, context):
+ query = (
+ "SELECT pdu_id, origin FROM %s WHERE context = ?"
+ % CurrentStateTable.table_name
+ )
+
+ logger.debug("get_current_state %s, Args=%s", query, context)
+ txn.execute(query, (context,))
+
+ res = txn.fetchall()
+
+ logger.debug("get_current_state %d results", len(res))
+
+ return self._get_pdu_tuples(txn, res)
+
+ def persist_pdu(self, prev_pdus, **cols):
+ """Inserts a (non-state) PDU into the database.
+
+ Args:
+ txn,
+ prev_pdus (list)
+ **cols: The columns to insert into the PdusTable.
+ """
+ return self._db_pool.runInteraction(
+ self._persist_pdu, prev_pdus, cols
+ )
+
+ def _persist_pdu(self, txn, prev_pdus, cols):
+ entry = PdusTable.EntryType(
+ **{k: cols.get(k, None) for k in PdusTable.fields}
+ )
+
+ txn.execute(PdusTable.insert_statement(), entry)
+
+ self._handle_prev_pdus(
+ txn, entry.outlier, entry.pdu_id, entry.origin,
+ prev_pdus, entry.context
+ )
+
+ def mark_pdu_as_processed(self, pdu_id, pdu_origin):
+ """Mark a received PDU as processed.
+
+ Args:
+ txn
+ pdu_id (str)
+ pdu_origin (str)
+ """
+
+ return self._db_pool.runInteraction(
+ self._mark_as_processed, pdu_id, pdu_origin
+ )
+
+ def _mark_as_processed(self, txn, pdu_id, pdu_origin):
+ txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name)
+
+ def get_all_pdus_from_context(self, context):
+ """Get a list of all PDUs for a given context."""
+ return self._db_pool.runInteraction(
+ self._get_all_pdus_from_context, context,
+ )
+
+ def _get_all_pdus_from_context(self, txn, context):
+ query = (
+ "SELECT pdu_id, origin FROM %s "
+ "WHERE context = ?"
+ ) % PdusTable.table_name
+
+ txn.execute(query, (context,))
+
+ return self._get_pdu_tuples(txn, txn.fetchall())
+
+ def get_pagination(self, context, pdu_list, limit):
+ """Get a list of Pdus for a given topic that occured before (and
+ including) the pdus in pdu_list. Return a list of max size `limit`.
+
+ Args:
+ txn
+ context (str)
+ pdu_list (list)
+ limit (int)
+
+ Return:
+ list: A list of PduTuples
+ """
+ return self._db_pool.runInteraction(
+ self._get_paginate, context, pdu_list, limit
+ )
+
+ def _get_paginate(self, txn, context, pdu_list, limit):
+ logger.debug(
+ "paginate: %s, %s, %s",
+ context, repr(pdu_list), limit
+ )
+
+ # We seed the pdu_results with the things from the pdu_list.
+ pdu_results = pdu_list
+
+ front = pdu_list
+
+ query = (
+ "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s "
+ "WHERE context = ? AND pdu_id = ? AND origin = ? "
+ "LIMIT ?"
+ ) % {
+ "edges_table": PduEdgesTable.table_name,
+ }
+
+ # We iterate through all pdu_ids in `front` to select their previous
+ # pdus. These are dumped in `new_front`. We continue until we reach the
+ # limit *or* new_front is empty (i.e., we've run out of things to
+ # select
+ while front and len(pdu_results) < limit:
+
+ new_front = []
+ for pdu_id, origin in front:
+ logger.debug(
+ "_paginate_interaction: i=%s, o=%s",
+ pdu_id, origin
+ )
+
+ txn.execute(
+ query,
+ (context, pdu_id, origin, limit - len(pdu_results))
+ )
+
+ for row in txn.fetchall():
+ logger.debug(
+ "_paginate_interaction: got i=%s, o=%s",
+ *row
+ )
+ new_front.append(row)
+
+ front = new_front
+ pdu_results += new_front
+
+ # We also want to update the `prev_pdus` attributes before returning.
+ return self._get_pdu_tuples(txn, pdu_results)
+
+ def get_min_depth_for_context(self, context):
+ """Get the current minimum depth for a context
+
+ Args:
+ txn
+ context (str)
+ """
+ return self._db_pool.runInteraction(
+ self._get_min_depth_for_context, context
+ )
+
+ def _get_min_depth_for_context(self, txn, context):
+ return self._get_min_depth_interaction(txn, context)
+
+ def _get_min_depth_interaction(self, txn, context):
+ txn.execute(
+ "SELECT min_depth FROM %s WHERE context = ?"
+ % ContextDepthTable.table_name,
+ (context,)
+ )
+
+ row = txn.fetchone()
+
+ return row[0] if row else None
+
+ def update_min_depth_for_context(self, context, depth):
+ """Update the minimum `depth` of the given context, which is the line
+ where we stop paginating backwards on.
+
+ Args:
+ context (str)
+ depth (int)
+ """
+ return self._db_pool.runInteraction(
+ self._update_min_depth_for_context, context, depth
+ )
+
+ def _update_min_depth_for_context(self, txn, context, depth):
+ min_depth = self._get_min_depth_interaction(txn, context)
+
+ do_insert = depth < min_depth if min_depth else True
+
+ if do_insert:
+ txn.execute(
+ "INSERT OR REPLACE INTO %s (context, min_depth) "
+ "VALUES (?,?)" % ContextDepthTable.table_name,
+ (context, depth)
+ )
+
+ def get_latest_pdus_in_context(self, context):
+ """Get's a list of the most current pdus for a given context. This is
+ used when we are sending a Pdu and need to fill out the `prev_pdus`
+ key
+
+ Args:
+ txn
+ context
+ """
+ return self._db_pool.runInteraction(
+ self._get_latest_pdus_in_context, context
+ )
+
+ def _get_latest_pdus_in_context(self, txn, context):
+ query = (
+ "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p "
+ "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id "
+ "AND f.origin = p.origin "
+ "WHERE f.context = ?"
+ ) % {
+ "pdus": PdusTable.table_name,
+ "forward": PduForwardExtremitiesTable.table_name,
+ }
+
+ logger.debug("get_prev query: %s", query)
+
+ txn.execute(
+ query,
+ (context, )
+ )
+
+ results = txn.fetchall()
+
+ return [(row[0], row[1], row[2]) for row in results]
+
+ def get_oldest_pdus_in_context(self, context):
+ """Get a list of Pdus that we paginated beyond yet (and haven't seen).
+ This list is used when we want to paginate backwards and is the list we
+ send to the remote server.
+
+ Args:
+ txn
+ context (str)
+
+ Returns:
+ list: A list of PduIdTuple.
+ """
+ return self._db_pool.runInteraction(
+ self._get_oldest_pdus_in_context, context
+ )
+
+ def _get_oldest_pdus_in_context(self, txn, context):
+ txn.execute(
+ "SELECT pdu_id, origin FROM %(back)s WHERE context = ?"
+ % {"back": PduBackwardExtremitiesTable.table_name, },
+ (context,)
+ )
+ return [PduIdTuple(i, o) for i, o in txn.fetchall()]
+
+ def is_pdu_new(self, pdu_id, origin, context, depth):
+ """For a given Pdu, try and figure out if it's 'new', i.e., if it's
+ not something we got randomly from the past, for example when we
+ request the current state of the room that will probably return a bunch
+ of pdus from before we joined.
+
+ Args:
+ txn
+ pdu_id (str)
+ origin (str)
+ context (str)
+ depth (int)
+
+ Returns:
+ bool
+ """
+
+ return self._db_pool.runInteraction(
+ self._is_pdu_new,
+ pdu_id=pdu_id,
+ origin=origin,
+ context=context,
+ depth=depth
+ )
+
+ def _is_pdu_new(self, txn, pdu_id, origin, context, depth):
+ # If depth > min depth in back table, then we classify it as new.
+ # OR if there is nothing in the back table, then it kinda needs to
+ # be a new thing.
+ query = (
+ "SELECT min(p.depth) FROM %(edges)s as e "
+ "INNER JOIN %(back)s as b "
+ "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin "
+ "INNER JOIN %(pdus)s as p "
+ "ON e.pdu_id = p.pdu_id AND p.origin = e.origin "
+ "WHERE p.context = ?"
+ ) % {
+ "pdus": PdusTable.table_name,
+ "edges": PduEdgesTable.table_name,
+ "back": PduBackwardExtremitiesTable.table_name,
+ }
+
+ txn.execute(query, (context,))
+
+ min_depth, = txn.fetchone()
+
+ if not min_depth or depth > int(min_depth):
+ logger.debug(
+ "is_new true: id=%s, o=%s, d=%s min_depth=%s",
+ pdu_id, origin, depth, min_depth
+ )
+ return True
+
+ # If this pdu is in the forwards table, then it also is a new one
+ query = (
+ "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?"
+ ) % {
+ "forward": PduForwardExtremitiesTable.table_name,
+ }
+
+ txn.execute(query, (pdu_id, origin))
+
+ # Did we get anything?
+ if txn.fetchall():
+ logger.debug(
+ "is_new true: id=%s, o=%s, d=%s was forward",
+ pdu_id, origin, depth
+ )
+ return True
+
+ logger.debug(
+ "is_new false: id=%s, o=%s, d=%s",
+ pdu_id, origin, depth
+ )
+
+ # FINE THEN. It's probably old.
+ return False
+
+ @staticmethod
+ @log_function
+ def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus,
+ context):
+ txn.executemany(
+ PduEdgesTable.insert_statement(),
+ [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus]
+ )
+
+ # Update the extremities table if this is not an outlier.
+ if not outlier:
+
+ # First, we delete the new one from the forwards extremities table.
+ query = (
+ "DELETE FROM %s WHERE pdu_id = ? AND origin = ?"
+ % PduForwardExtremitiesTable.table_name
+ )
+ txn.executemany(query, prev_pdus)
+
+ # We only insert as a forward extremety the new pdu if there are no
+ # other pdus that reference it as a prev pdu
+ query = (
+ "INSERT INTO %(table)s (pdu_id, origin, context) "
+ "SELECT ?, ?, ? WHERE NOT EXISTS ("
+ "SELECT 1 FROM %(pdu_edges)s WHERE "
+ "prev_pdu_id = ? AND prev_origin = ?"
+ ")"
+ ) % {
+ "table": PduForwardExtremitiesTable.table_name,
+ "pdu_edges": PduEdgesTable.table_name
+ }
+
+ logger.debug("query: %s", query)
+
+ txn.execute(query, (pdu_id, origin, context, pdu_id, origin))
+
+ # Insert all the prev_pdus as a backwards thing, they'll get
+ # deleted in a second if they're incorrect anyway.
+ txn.executemany(
+ PduBackwardExtremitiesTable.insert_statement(),
+ [(i, o, context) for i, o in prev_pdus]
+ )
+
+ # Also delete from the backwards extremities table all ones that
+ # reference pdus that we have already seen
+ query = (
+ "DELETE FROM %(pdu_back)s WHERE EXISTS ("
+ "SELECT 1 FROM %(pdus)s AS pdus "
+ "WHERE "
+ "%(pdu_back)s.pdu_id = pdus.pdu_id "
+ "AND %(pdu_back)s.origin = pdus.origin "
+ "AND not pdus.outlier "
+ ")"
+ ) % {
+ "pdu_back": PduBackwardExtremitiesTable.table_name,
+ "pdus": PdusTable.table_name,
+ }
+ txn.execute(query)
+
+
+class StatePduStore(SQLBaseStore):
+ """A collection of queries for handling state PDUs.
+ """
+
+ def persist_state(self, prev_pdus, **cols):
+ """Inserts a state PDU into the database
+
+ Args:
+ txn,
+ prev_pdus (list)
+ **cols: The columns to insert into the PdusTable and StatePdusTable
+ """
+
+ return self._db_pool.runInteraction(
+ self._persist_state, prev_pdus, cols
+ )
+
+ def _persist_state(self, txn, prev_pdus, cols):
+ pdu_entry = PdusTable.EntryType(
+ **{k: cols.get(k, None) for k in PdusTable.fields}
+ )
+ state_entry = StatePdusTable.EntryType(
+ **{k: cols.get(k, None) for k in StatePdusTable.fields}
+ )
+
+ logger.debug("Inserting pdu: %s", repr(pdu_entry))
+ logger.debug("Inserting state: %s", repr(state_entry))
+
+ txn.execute(PdusTable.insert_statement(), pdu_entry)
+ txn.execute(StatePdusTable.insert_statement(), state_entry)
+
+ self._handle_prev_pdus(
+ txn,
+ pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus,
+ pdu_entry.context
+ )
+
+ def get_unresolved_state_tree(self, new_state_pdu):
+ return self._db_pool.runInteraction(
+ self._get_unresolved_state_tree, new_state_pdu
+ )
+
+ @log_function
+ def _get_unresolved_state_tree(self, txn, new_pdu):
+ current = self._get_current_interaction(
+ txn,
+ new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
+ )
+
+ ReturnType = namedtuple(
+ "StateReturnType", ["new_branch", "current_branch"]
+ )
+ return_value = ReturnType([new_pdu], [])
+
+ if not current:
+ logger.debug("get_unresolved_state_tree No current state.")
+ return return_value
+
+ return_value.current_branch.append(current)
+
+ enum_branches = self._enumerate_state_branches(
+ txn, new_pdu, current
+ )
+
+ for branch, prev_state, state in enum_branches:
+ if state:
+ return_value[branch].append(state)
+ else:
+ break
+
+ return return_value
+
+ def update_current_state(self, pdu_id, origin, context, pdu_type,
+ state_key):
+ return self._db_pool.runInteraction(
+ self._update_current_state,
+ pdu_id, origin, context, pdu_type, state_key
+ )
+
+ def _update_current_state(self, txn, pdu_id, origin, context, pdu_type,
+ state_key):
+ query = (
+ "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
+ ) % {
+ "curr": CurrentStateTable.table_name,
+ "fields": CurrentStateTable.get_fields_string(),
+ "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
+ }
+
+ query_args = CurrentStateTable.EntryType(
+ pdu_id=pdu_id,
+ origin=origin,
+ context=context,
+ pdu_type=pdu_type,
+ state_key=state_key
+ )
+
+ txn.execute(query, query_args)
+
+ def get_current_state(self, context, pdu_type, state_key):
+ """For a given context, pdu_type, state_key 3-tuple, return what is
+ currently considered the current state.
+
+ Args:
+ txn
+ context (str)
+ pdu_type (str)
+ state_key (str)
+
+ Returns:
+ PduEntry
+ """
+
+ return self._db_pool.runInteraction(
+ self._get_current_state, context, pdu_type, state_key
+ )
+
+ def _get_current_state(self, txn, context, pdu_type, state_key):
+ return self._get_current_interaction(txn, context, pdu_type, state_key)
+
+ def _get_current_interaction(self, txn, context, pdu_type, state_key):
+ logger.debug(
+ "_get_current_interaction %s %s %s",
+ context, pdu_type, state_key
+ )
+
+ fields = _pdu_state_joiner.get_fields(
+ PdusTable="p", StatePdusTable="s")
+
+ current_query = (
+ "SELECT %(fields)s FROM %(state)s as s "
+ "INNER JOIN %(pdus)s as p "
+ "ON s.pdu_id = p.pdu_id AND s.origin = p.origin "
+ "INNER JOIN %(curr)s as c "
+ "ON s.pdu_id = c.pdu_id AND s.origin = c.origin "
+ "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? "
+ ) % {
+ "fields": fields,
+ "curr": CurrentStateTable.table_name,
+ "state": StatePdusTable.table_name,
+ "pdus": PdusTable.table_name,
+ }
+
+ txn.execute(
+ current_query,
+ (context, pdu_type, state_key)
+ )
+
+ row = txn.fetchone()
+
+ result = PduEntry(*row) if row else None
+
+ if not result:
+ logger.debug("_get_current_interaction not found")
+ else:
+ logger.debug(
+ "_get_current_interaction found %s %s",
+ result.pdu_id, result.origin
+ )
+
+ return result
+
+ def get_next_missing_pdu(self, new_pdu):
+ """When we get a new state pdu we need to check whether we need to do
+ any conflict resolution, if we do then we need to check if we need
+ to go back and request some more state pdus that we haven't seen yet.
+
+ Args:
+ txn
+ new_pdu
+
+ Returns:
+ PduIdTuple: A pdu that we are missing, or None if we have all the
+ pdus required to do the conflict resolution.
+ """
+ return self._db_pool.runInteraction(
+ self._get_next_missing_pdu, new_pdu
+ )
+
+ def _get_next_missing_pdu(self, txn, new_pdu):
+ logger.debug(
+ "get_next_missing_pdu %s %s",
+ new_pdu.pdu_id, new_pdu.origin
+ )
+
+ current = self._get_current_interaction(
+ txn,
+ new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
+ )
+
+ if (not current or not current.prev_state_id
+ or not current.prev_state_origin):
+ return None
+
+ # Oh look, it's a straight clobber, so wooooo almost no-op.
+ if (new_pdu.prev_state_id == current.pdu_id
+ and new_pdu.prev_state_origin == current.origin):
+ return None
+
+ enum_branches = self._enumerate_state_branches(txn, new_pdu, current)
+ for branch, prev_state, state in enum_branches:
+ if not state:
+ return PduIdTuple(
+ prev_state.prev_state_id,
+ prev_state.prev_state_origin
+ )
+
+ return None
+
+ def handle_new_state(self, new_pdu):
+ """Actually perform conflict resolution on the new_pdu on the
+ assumption we have all the pdus required to perform it.
+
+ Args:
+ new_pdu
+
+ Returns:
+ bool: True if the new_pdu clobbered the current state, False if not
+ """
+ return self._db_pool.runInteraction(
+ self._handle_new_state, new_pdu
+ )
+
+ def _handle_new_state(self, txn, new_pdu):
+ logger.debug(
+ "handle_new_state %s %s",
+ new_pdu.pdu_id, new_pdu.origin
+ )
+
+ current = self._get_current_interaction(
+ txn,
+ new_pdu.context, new_pdu.pdu_type, new_pdu.state_key
+ )
+
+ is_current = False
+
+ if (not current or not current.prev_state_id
+ or not current.prev_state_origin):
+ # Oh, we don't have any state for this yet.
+ is_current = True
+ elif (current.pdu_id == new_pdu.prev_state_id
+ and current.origin == new_pdu.prev_state_origin):
+ # Oh! A direct clobber. Just do it.
+ is_current = True
+ else:
+ ##
+ # Ok, now loop through until we get to a common ancestor.
+ max_new = int(new_pdu.power_level)
+ max_current = int(current.power_level)
+
+ enum_branches = self._enumerate_state_branches(
+ txn, new_pdu, current
+ )
+ for branch, prev_state, state in enum_branches:
+ if not state:
+ raise RuntimeError(
+ "Could not find state_pdu %s %s" %
+ (
+ prev_state.prev_state_id,
+ prev_state.prev_state_origin
+ )
+ )
+
+ if branch == 0:
+ max_new = max(int(state.depth), max_new)
+ else:
+ max_current = max(int(state.depth), max_current)
+
+ is_current = max_new > max_current
+
+ if is_current:
+ logger.debug("handle_new_state make current")
+
+ # Right, this is a new thing, so woo, just insert it.
+ txn.execute(
+ "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)"
+ % {
+ "curr": CurrentStateTable.table_name,
+ "fields": CurrentStateTable.get_fields_string(),
+ "qs": ", ".join(["?"] * len(CurrentStateTable.fields))
+ },
+ CurrentStateTable.EntryType(
+ *(new_pdu.__dict__[k] for k in CurrentStateTable.fields)
+ )
+ )
+ else:
+ logger.debug("handle_new_state not current")
+
+ logger.debug("handle_new_state done")
+
+ return is_current
+
+ @classmethod
+ @log_function
+ def _enumerate_state_branches(cls, txn, pdu_a, pdu_b):
+ branch_a = pdu_a
+ branch_b = pdu_b
+
+ get_query = (
+ "SELECT %(fields)s FROM %(pdus)s as p "
+ "LEFT JOIN %(state)s as s "
+ "ON p.pdu_id = s.pdu_id AND p.origin = s.origin "
+ "WHERE p.pdu_id = ? AND p.origin = ? "
+ ) % {
+ "fields": _pdu_state_joiner.get_fields(
+ PdusTable="p", StatePdusTable="s"),
+ "pdus": PdusTable.table_name,
+ "state": StatePdusTable.table_name,
+ }
+
+ while True:
+ if (branch_a.pdu_id == branch_b.pdu_id
+ and branch_a.origin == branch_b.origin):
+ # Woo! We found a common ancestor
+ logger.debug("_enumerate_state_branches Found common ancestor")
+ break
+
+ do_branch_a = (
+ hasattr(branch_a, "prev_state_id") and
+ branch_a.prev_state_id
+ )
+
+ do_branch_b = (
+ hasattr(branch_b, "prev_state_id") and
+ branch_b.prev_state_id
+ )
+
+ logger.debug(
+ "do_branch_a=%s, do_branch_b=%s",
+ do_branch_a, do_branch_b
+ )
+
+ if do_branch_a and do_branch_b:
+ do_branch_a = int(branch_a.depth) > int(branch_b.depth)
+
+ if do_branch_a:
+ pdu_tuple = PduIdTuple(
+ branch_a.prev_state_id,
+ branch_a.prev_state_origin
+ )
+
+ logger.debug("getting branch_a prev %s", pdu_tuple)
+ txn.execute(get_query, pdu_tuple)
+
+ prev_branch = branch_a
+
+ res = txn.fetchone()
+ branch_a = PduEntry(*res) if res else None
+
+ logger.debug("branch_a=%s", branch_a)
+
+ yield (0, prev_branch, branch_a)
+
+ if not branch_a:
+ break
+ elif do_branch_b:
+ pdu_tuple = PduIdTuple(
+ branch_b.prev_state_id,
+ branch_b.prev_state_origin
+ )
+ txn.execute(get_query, pdu_tuple)
+
+ logger.debug("getting branch_b prev %s", pdu_tuple)
+
+ prev_branch = branch_b
+
+ res = txn.fetchone()
+ branch_b = PduEntry(*res) if res else None
+
+ logger.debug("branch_b=%s", branch_b)
+
+ yield (1, prev_branch, branch_b)
+
+ if not branch_b:
+ break
+ else:
+ break
+
+
+class PdusTable(Table):
+ table_name = "pdus"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "context",
+ "pdu_type",
+ "ts",
+ "depth",
+ "is_state",
+ "content_json",
+ "unrecognized_keys",
+ "outlier",
+ "have_processed",
+ ]
+
+ EntryType = namedtuple("PdusEntry", fields)
+
+
+class PduDestinationsTable(Table):
+ table_name = "pdu_destinations"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "destination",
+ "delivered_ts",
+ ]
+
+ EntryType = namedtuple("PduDestinationsEntry", fields)
+
+
+class PduEdgesTable(Table):
+ table_name = "pdu_edges"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "prev_pdu_id",
+ "prev_origin",
+ "context"
+ ]
+
+ EntryType = namedtuple("PduEdgesEntry", fields)
+
+
+class PduForwardExtremitiesTable(Table):
+ table_name = "pdu_forward_extremities"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "context",
+ ]
+
+ EntryType = namedtuple("PduForwardExtremitiesEntry", fields)
+
+
+class PduBackwardExtremitiesTable(Table):
+ table_name = "pdu_backward_extremities"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "context",
+ ]
+
+ EntryType = namedtuple("PduBackwardExtremitiesEntry", fields)
+
+
+class ContextDepthTable(Table):
+ table_name = "context_depth"
+
+ fields = [
+ "context",
+ "min_depth",
+ ]
+
+ EntryType = namedtuple("ContextDepthEntry", fields)
+
+
+class StatePdusTable(Table):
+ table_name = "state_pdus"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "context",
+ "pdu_type",
+ "state_key",
+ "power_level",
+ "prev_state_id",
+ "prev_state_origin",
+ ]
+
+ EntryType = namedtuple("StatePdusEntry", fields)
+
+
+class CurrentStateTable(Table):
+ table_name = "current_state"
+
+ fields = [
+ "pdu_id",
+ "origin",
+ "context",
+ "pdu_type",
+ "state_key",
+ ]
+
+ EntryType = namedtuple("CurrentStateEntry", fields)
+
+_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable)
+
+
+# TODO: These should probably be put somewhere more sensible
+PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin"))
+
+PduEntry = _pdu_state_joiner.EntryType
+""" We are always interested in the join of the PdusTable and StatePdusTable,
+rather than just the PdusTable.
+
+This does not include a prev_pdus key.
+"""
+
+PduTuple = namedtuple(
+ "PduTuple",
+ ("pdu_entry", "prev_pdu_list")
+)
+""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent
+the `prev_pdus` key of a PDU.
+"""
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
new file mode 100644
index 0000000000..e57ddaf149
--- /dev/null
+++ b/synapse/storage/presence.py
@@ -0,0 +1,103 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+
+
+class PresenceStore(SQLBaseStore):
+ def create_presence(self, user_localpart):
+ return self._simple_insert(
+ table="presence",
+ values={"user_id": user_localpart},
+ )
+
+ def has_presence_state(self, user_localpart):
+ return self._simple_select_one(
+ table="presence",
+ keyvalues={"user_id": user_localpart},
+ retcols=["user_id"],
+ allow_none=True,
+ )
+
+ def get_presence_state(self, user_localpart):
+ return self._simple_select_one(
+ table="presence",
+ keyvalues={"user_id": user_localpart},
+ retcols=["state", "status_msg"],
+ )
+
+ def set_presence_state(self, user_localpart, new_state):
+ return self._simple_update_one(
+ table="presence",
+ keyvalues={"user_id": user_localpart},
+ updatevalues={"state": new_state["state"],
+ "status_msg": new_state["status_msg"]},
+ retcols=["state"],
+ )
+
+ def allow_presence_visible(self, observed_localpart, observer_userid):
+ return self._simple_insert(
+ table="presence_allow_inbound",
+ values={"observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid},
+ )
+
+ def disallow_presence_visible(self, observed_localpart, observer_userid):
+ return self._simple_delete_one(
+ table="presence_allow_inbound",
+ keyvalues={"observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid},
+ )
+
+ def is_presence_visible(self, observed_localpart, observer_userid):
+ return self._simple_select_one(
+ table="presence_allow_inbound",
+ keyvalues={"observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid},
+ allow_none=True,
+ )
+
+ def add_presence_list_pending(self, observer_localpart, observed_userid):
+ return self._simple_insert(
+ table="presence_list",
+ values={"user_id": observer_localpart,
+ "observed_user_id": observed_userid,
+ "accepted": False},
+ )
+
+ def set_presence_list_accepted(self, observer_localpart, observed_userid):
+ return self._simple_update_one(
+ table="presence_list",
+ keyvalues={"user_id": observer_localpart,
+ "observed_user_id": observed_userid},
+ updatevalues={"accepted": True},
+ )
+
+ def get_presence_list(self, observer_localpart, accepted=None):
+ keyvalues = {"user_id": observer_localpart}
+ if accepted is not None:
+ keyvalues["accepted"] = accepted
+
+ return self._simple_select_list(
+ table="presence_list",
+ keyvalues=keyvalues,
+ retcols=["observed_user_id", "accepted"],
+ )
+
+ def del_presence_list(self, observer_localpart, observed_userid):
+ return self._simple_delete_one(
+ table="presence_list",
+ keyvalues={"user_id": observer_localpart,
+ "observed_user_id": observed_userid},
+ )
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
new file mode 100644
index 0000000000..d2f24930c1
--- /dev/null
+++ b/synapse/storage/profile.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore
+
+
+class ProfileStore(SQLBaseStore):
+ def create_profile(self, user_localpart):
+ return self._simple_insert(
+ table="profiles",
+ values={"user_id": user_localpart},
+ )
+
+ def get_profile_displayname(self, user_localpart):
+ return self._simple_select_one_onecol(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ retcol="displayname",
+ )
+
+ def set_profile_displayname(self, user_localpart, new_displayname):
+ return self._simple_update_one(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ updatevalues={"displayname": new_displayname},
+ )
+
+ def get_profile_avatar_url(self, user_localpart):
+ return self._simple_select_one_onecol(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ retcol="avatar_url",
+ )
+
+ def set_profile_avatar_url(self, user_localpart, new_avatar_url):
+ return self._simple_update_one(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ updatevalues={"avatar_url": new_avatar_url},
+ )
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
new file mode 100644
index 0000000000..4a970dd546
--- /dev/null
+++ b/synapse/storage/registration.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from sqlite3 import IntegrityError
+
+from synapse.api.errors import StoreError
+
+from ._base import SQLBaseStore
+
+
+class RegistrationStore(SQLBaseStore):
+
+ def __init__(self, hs):
+ super(RegistrationStore, self).__init__(hs)
+
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def add_access_token_to_user(self, user_id, token):
+ """Adds an access token for the given user.
+
+ Args:
+ user_id (str): The user ID.
+ token (str): The new access token to add.
+ Raises:
+ StoreError if there was a problem adding this.
+ """
+ row = yield self._simple_select_one("users", {"name": user_id}, ["id"])
+ if not row:
+ raise StoreError(400, "Bad user ID supplied.")
+ row_id = row["id"]
+ yield self._simple_insert(
+ "access_tokens",
+ {
+ "user_id": row_id,
+ "token": token
+ }
+ )
+
+ @defer.inlineCallbacks
+ def register(self, user_id, token, password_hash):
+ """Attempts to register an account.
+
+ Args:
+ user_id (str): The desired user ID to register.
+ token (str): The desired access token to use for this user.
+ password_hash (str): Optional. The password hash for this user.
+ Raises:
+ StoreError if the user_id could not be registered.
+ """
+ yield self._db_pool.runInteraction(self._register, user_id, token,
+ password_hash)
+
+ def _register(self, txn, user_id, token, password_hash):
+ now = int(self.clock.time())
+
+ try:
+ txn.execute("INSERT INTO users(name, password_hash, creation_ts) "
+ "VALUES (?,?,?)",
+ [user_id, password_hash, now])
+ except IntegrityError:
+ raise StoreError(400, "User ID already taken.")
+
+ # it's possible for this to get a conflict, but only for a single user
+ # since tokens are namespaced based on their user ID
+ txn.execute("INSERT INTO access_tokens(user_id, token) " +
+ "VALUES (?,?)", [txn.lastrowid, token])
+
+ def get_user_by_id(self, user_id):
+ query = ("SELECT users.name, users.password_hash FROM users "
+ "WHERE users.name = ?")
+ return self._execute(
+ self.cursor_to_dict,
+ query, user_id
+ )
+
+ @defer.inlineCallbacks
+ def get_user_by_token(self, token):
+ """Get a user from the given access token.
+
+ Args:
+ token (str): The access token of a user.
+ Returns:
+ str: The user ID of the user.
+ Raises:
+ StoreError if no user was found.
+ """
+ user_id = yield self._db_pool.runInteraction(self._query_for_auth,
+ token)
+ defer.returnValue(user_id)
+
+ def _query_for_auth(self, txn, token):
+ txn.execute("SELECT users.name FROM access_tokens LEFT JOIN users" +
+ " ON users.id = access_tokens.user_id WHERE token = ?",
+ [token])
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ raise StoreError(404, "Token not found.")
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
new file mode 100644
index 0000000000..174cbcf3d8
--- /dev/null
+++ b/synapse/storage/room.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from sqlite3 import IntegrityError
+
+from synapse.api.errors import StoreError
+from synapse.api.events.room import RoomTopicEvent
+
+from ._base import SQLBaseStore, Table
+
+import collections
+import json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class RoomStore(SQLBaseStore):
+
+ @defer.inlineCallbacks
+ def store_room(self, room_id, room_creator_user_id, is_public):
+ """Stores a room.
+
+ Args:
+ room_id (str): The desired room ID, can be None.
+ room_creator_user_id (str): The user ID of the room creator.
+ is_public (bool): True to indicate that this room should appear in
+ public room lists.
+ Raises:
+ StoreError if the room could not be stored.
+ """
+ try:
+ yield self._simple_insert(RoomsTable.table_name, dict(
+ room_id=room_id,
+ creator=room_creator_user_id,
+ is_public=is_public
+ ))
+ except IntegrityError:
+ raise StoreError(409, "Room ID in use.")
+ except Exception as e:
+ logger.error("store_room with room_id=%s failed: %s", room_id, e)
+ raise StoreError(500, "Problem creating room.")
+
+ def store_room_config(self, room_id, visibility):
+ return self._simple_update_one(
+ table=RoomsTable.table_name,
+ keyvalues={"room_id": room_id},
+ updatevalues={"is_public": visibility}
+ )
+
+ def get_room(self, room_id):
+ """Retrieve a room.
+
+ Args:
+ room_id (str): The ID of the room to retrieve.
+ Returns:
+ A namedtuple containing the room information, or an empty list.
+ """
+ query = RoomsTable.select_statement("room_id=?")
+ return self._execute(
+ RoomsTable.decode_single_result, query, room_id,
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms(self, is_public, with_topics):
+ """Retrieve a list of all public rooms.
+
+ Args:
+ is_public (bool): True if the rooms returned should be public.
+ with_topics (bool): True to include the current topic for the room
+ in the response.
+ Returns:
+ A list of room dicts containing at least a "room_id" key, and a
+ "topic" key if one is set and with_topic=True.
+ """
+ room_data_type = RoomTopicEvent.TYPE
+ public = 1 if is_public else 0
+
+ latest_topic = ("SELECT max(room_data.id) FROM room_data WHERE "
+ + "room_data.type = ? GROUP BY room_id")
+
+ query = ("SELECT rooms.*, room_data.content FROM rooms LEFT JOIN "
+ + "room_data ON rooms.room_id = room_data.room_id WHERE "
+ + "(room_data.id IN (" + latest_topic + ") "
+ + "OR room_data.id IS NULL) AND rooms.is_public = ?")
+
+ res = yield self._execute(
+ self.cursor_to_dict, query, room_data_type, public
+ )
+
+ # return only the keys the specification expects
+ ret_keys = ["room_id", "topic"]
+
+ # extract topic from the json (icky) FIXME
+ for i, room_row in enumerate(res):
+ try:
+ content_json = json.loads(room_row["content"])
+ room_row["topic"] = content_json["topic"]
+ except:
+ pass # no topic set
+ # filter the dict based on ret_keys
+ res[i] = {k: v for k, v in room_row.iteritems() if k in ret_keys}
+
+ defer.returnValue(res)
+
+
+class RoomsTable(Table):
+ table_name = "rooms"
+
+ fields = [
+ "room_id",
+ "is_public",
+ "creator"
+ ]
+
+ EntryType = collections.namedtuple("RoomEntry", fields)
diff --git a/synapse/storage/roomdata.py b/synapse/storage/roomdata.py
new file mode 100644
index 0000000000..781d477931
--- /dev/null
+++ b/synapse/storage/roomdata.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore, Table
+
+import collections
+import json
+
+
+class RoomDataStore(SQLBaseStore):
+
+ """Provides various CRUD operations for Room Events. """
+
+ def get_room_data(self, room_id, etype, state_key=""):
+ """Retrieve the data stored under this type and state_key.
+
+ Args:
+ room_id (str)
+ etype (str)
+ state_key (str)
+ Returns:
+ namedtuple: Or None if nothing exists at this path.
+ """
+ query = RoomDataTable.select_statement(
+ "room_id = ? AND type = ? AND state_key = ? "
+ "ORDER BY id DESC LIMIT 1"
+ )
+ return self._execute(
+ RoomDataTable.decode_single_result,
+ query, room_id, etype, state_key,
+ )
+
+ def store_room_data(self, room_id, etype, state_key="", content=None):
+ """Stores room specific data.
+
+ Args:
+ room_id (str)
+ etype (str)
+ state_key (str)
+ data (str)- The data to store for this path in JSON.
+ Returns:
+ The store ID for this data.
+ """
+ return self._simple_insert(RoomDataTable.table_name, dict(
+ etype=etype,
+ state_key=state_key,
+ room_id=room_id,
+ content=content,
+ ))
+
+ def get_max_room_data_id(self):
+ return self._simple_max_id(RoomDataTable.table_name)
+
+
+class RoomDataTable(Table):
+ table_name = "room_data"
+
+ fields = [
+ "id",
+ "room_id",
+ "type",
+ "state_key",
+ "content"
+ ]
+
+ class EntryType(collections.namedtuple("RoomDataEntry", fields)):
+
+ def as_event(self, event_factory):
+ return event_factory.create_event(
+ etype=self.type,
+ room_id=self.room_id,
+ content=json.loads(self.content),
+ )
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
new file mode 100644
index 0000000000..e6e7617797
--- /dev/null
+++ b/synapse/storage/roommember.py
@@ -0,0 +1,171 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.types import UserID
+from synapse.api.constants import Membership
+from synapse.api.events.room import RoomMemberEvent
+
+from ._base import SQLBaseStore, Table
+
+
+import collections
+import json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class RoomMemberStore(SQLBaseStore):
+
+ def get_room_member(self, user_id, room_id):
+ """Retrieve the current state of a room member.
+
+ Args:
+ user_id (str): The member's user ID.
+ room_id (str): The room the member is in.
+ Returns:
+ namedtuple: The room member from the database, or None if this
+ member does not exist.
+ """
+ query = RoomMemberTable.select_statement(
+ "room_id = ? AND user_id = ? ORDER BY id DESC LIMIT 1")
+ return self._execute(
+ RoomMemberTable.decode_single_result,
+ query, room_id, user_id,
+ )
+
+ def store_room_member(self, user_id, sender, room_id, membership, content):
+ """Store a room member in the database.
+
+ Args:
+ user_id (str): The member's user ID.
+ room_id (str): The room in relation to the member.
+ membership (synapse.api.constants.Membership): The new membership
+ state.
+ content (dict): The content of the membership (JSON).
+ """
+ content_json = json.dumps(content)
+ return self._simple_insert(RoomMemberTable.table_name, dict(
+ user_id=user_id,
+ sender=sender,
+ room_id=room_id,
+ membership=membership,
+ content=content_json,
+ ))
+
+ @defer.inlineCallbacks
+ def get_room_members(self, room_id, membership=None):
+ """Retrieve the current room member list for a room.
+
+ Args:
+ room_id (str): The room to get the list of members.
+ membership (synapse.api.constants.Membership): The filter to apply
+ to this list, or None to return all members with some state
+ associated with this room.
+ Returns:
+ list of namedtuples representing the members in this room.
+ """
+ query = RoomMemberTable.select_statement(
+ "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
+ + " WHERE room_id = ? GROUP BY user_id)"
+ )
+ res = yield self._execute(
+ RoomMemberTable.decode_results, query, room_id,
+ )
+ # strip memberships which don't match
+ if membership:
+ res = [entry for entry in res if entry.membership == membership]
+ defer.returnValue(res)
+
+ def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
+ """ Get all the rooms for this user where the membership for this user
+ matches one in the membership list.
+
+ Args:
+ user_id (str): The user ID.
+ membership_list (list): A list of synapse.api.constants.Membership
+ values which the user must be in.
+ Returns:
+ A list of dicts with "room_id" and "membership" keys.
+ """
+ if not membership_list:
+ return defer.succeed(None)
+
+ args = [user_id]
+ membership_placeholder = ["membership=?"] * len(membership_list)
+ where_membership = "(" + " OR ".join(membership_placeholder) + ")"
+ for membership in membership_list:
+ args.append(membership)
+
+ query = ("SELECT room_id, membership FROM room_memberships"
+ + " WHERE user_id=? AND " + where_membership
+ + " GROUP BY room_id ORDER BY id DESC")
+ return self._execute(
+ self.cursor_to_dict, query, *args
+ )
+
+ @defer.inlineCallbacks
+ def get_joined_hosts_for_room(self, room_id):
+ query = RoomMemberTable.select_statement(
+ "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name
+ + " WHERE room_id = ? GROUP BY user_id)"
+ )
+
+ res = yield self._execute(
+ RoomMemberTable.decode_results, query, room_id,
+ )
+
+ def host_from_user_id_string(user_id):
+ domain = UserID.from_string(entry.user_id, self.hs).domain
+ return domain
+
+ # strip memberships which don't match
+ hosts = [
+ host_from_user_id_string(entry.user_id)
+ for entry in res
+ if entry.membership == Membership.JOIN
+ ]
+
+ logger.debug("Returning hosts: %s from results: %s", hosts, res)
+
+ defer.returnValue(hosts)
+
+ def get_max_room_member_id(self):
+ return self._simple_max_id(RoomMemberTable.table_name)
+
+
+class RoomMemberTable(Table):
+ table_name = "room_memberships"
+
+ fields = [
+ "id",
+ "user_id",
+ "sender",
+ "room_id",
+ "membership",
+ "content"
+ ]
+
+ class EntryType(collections.namedtuple("RoomMemberEntry", fields)):
+
+ def as_event(self, event_factory):
+ return event_factory.create_event(
+ etype=RoomMemberEvent.TYPE,
+ room_id=self.room_id,
+ target_user_id=self.user_id,
+ user_id=self.sender,
+ content=json.loads(self.content),
+ )
diff --git a/synapse/storage/schema/edge_pdus.sql b/synapse/storage/schema/edge_pdus.sql
new file mode 100644
index 0000000000..17b3c52f0d
--- /dev/null
+++ b/synapse/storage/schema/edge_pdus.sql
@@ -0,0 +1,31 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS context_edge_pdus(
+ id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin)
+);
+
+CREATE TABLE IF NOT EXISTS origin_edge_pdus(
+ id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this
+ pdu_id TEXT,
+ origin TEXT,
+ CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin)
+);
+
+CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin);
+CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin);
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
new file mode 100644
index 0000000000..77096546b2
--- /dev/null
+++ b/synapse/storage/schema/im.sql
@@ -0,0 +1,54 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS rooms(
+ room_id TEXT PRIMARY KEY NOT NULL,
+ is_public INTEGER,
+ creator TEXT
+);
+
+CREATE TABLE IF NOT EXISTS room_memberships(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id TEXT NOT NULL, -- no foreign key to users table, it could be an id belonging to another home server
+ sender TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ membership TEXT NOT NULL,
+ content TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS messages(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id TEXT,
+ room_id TEXT,
+ msg_id TEXT,
+ content TEXT
+);
+
+CREATE TABLE IF NOT EXISTS feedback(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ content TEXT,
+ feedback_type TEXT,
+ fb_sender_id TEXT,
+ msg_id TEXT,
+ room_id TEXT,
+ msg_sender_id TEXT
+);
+
+CREATE TABLE IF NOT EXISTS room_data(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ content TEXT
+);
diff --git a/synapse/storage/schema/pdu.sql b/synapse/storage/schema/pdu.sql
new file mode 100644
index 0000000000..ca3de005e9
--- /dev/null
+++ b/synapse/storage/schema/pdu.sql
@@ -0,0 +1,106 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- Stores pdus and their content
+CREATE TABLE IF NOT EXISTS pdus(
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ pdu_type TEXT,
+ ts INTEGER,
+ depth INTEGER DEFAULT 0 NOT NULL,
+ is_state BOOL,
+ content_json TEXT,
+ unrecognized_keys TEXT,
+ outlier BOOL NOT NULL,
+ have_processed BOOL,
+ CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
+);
+
+-- Stores what the current state pdu is for a given (context, pdu_type, key) tuple
+CREATE TABLE IF NOT EXISTS state_pdus(
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ pdu_type TEXT,
+ state_key TEXT,
+ power_level TEXT,
+ prev_state_id TEXT,
+ prev_state_origin TEXT,
+ CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
+ CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin)
+);
+
+CREATE TABLE IF NOT EXISTS current_state(
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ pdu_type TEXT,
+ state_key TEXT,
+ CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin)
+ CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE
+);
+
+-- Stores where each pdu we want to send should be sent and the delivery status.
+create TABLE IF NOT EXISTS pdu_destinations(
+ pdu_id TEXT,
+ origin TEXT,
+ destination TEXT,
+ delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered
+ CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE
+);
+
+CREATE TABLE IF NOT EXISTS pdu_forward_extremities(
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
+);
+
+CREATE TABLE IF NOT EXISTS pdu_backward_extremities(
+ pdu_id TEXT,
+ origin TEXT,
+ context TEXT,
+ CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE
+);
+
+CREATE TABLE IF NOT EXISTS pdu_edges(
+ pdu_id TEXT,
+ origin TEXT,
+ prev_pdu_id TEXT,
+ prev_origin TEXT,
+ context TEXT,
+ CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context)
+);
+
+CREATE TABLE IF NOT EXISTS context_depth(
+ context TEXT,
+ min_depth INTEGER,
+ CONSTRAINT uniqueness UNIQUE (context)
+);
+
+CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context);
+
+
+CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin);
+
+CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin);
+-- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination);
+
+CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context);
+CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin);
+
+CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin);
+
+CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context);
diff --git a/synapse/storage/schema/presence.sql b/synapse/storage/schema/presence.sql
new file mode 100644
index 0000000000..b22e3ba863
--- /dev/null
+++ b/synapse/storage/schema/presence.sql
@@ -0,0 +1,37 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS presence(
+ user_id INTEGER NOT NULL,
+ state INTEGER,
+ status_msg TEXT,
+ FOREIGN KEY(user_id) REFERENCES users(id)
+);
+
+-- For each of /my/ users which possibly-remote users are allowed to see their
+-- presence state
+CREATE TABLE IF NOT EXISTS presence_allow_inbound(
+ observed_user_id INTEGER NOT NULL,
+ observer_user_id TEXT, -- a UserID,
+ FOREIGN KEY(observed_user_id) REFERENCES users(id)
+);
+
+-- For each of /my/ users (watcher), which possibly-remote users are they
+-- watching?
+CREATE TABLE IF NOT EXISTS presence_list(
+ user_id INTEGER NOT NULL,
+ observed_user_id TEXT, -- a UserID,
+ accepted BOOLEAN,
+ FOREIGN KEY(user_id) REFERENCES users(id)
+);
diff --git a/synapse/storage/schema/profiles.sql b/synapse/storage/schema/profiles.sql
new file mode 100644
index 0000000000..1092d7672c
--- /dev/null
+++ b/synapse/storage/schema/profiles.sql
@@ -0,0 +1,20 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS profiles(
+ user_id INTEGER NOT NULL,
+ displayname TEXT,
+ avatar_url TEXT,
+ FOREIGN KEY(user_id) REFERENCES users(id)
+);
diff --git a/synapse/storage/schema/room_aliases.sql b/synapse/storage/schema/room_aliases.sql
new file mode 100644
index 0000000000..71a8b90e4d
--- /dev/null
+++ b/synapse/storage/schema/room_aliases.sql
@@ -0,0 +1,12 @@
+CREATE TABLE IF NOT EXISTS room_aliases(
+ room_alias TEXT NOT NULL,
+ room_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS room_alias_servers(
+ room_alias TEXT NOT NULL,
+ server TEXT NOT NULL
+);
+
+
+
diff --git a/synapse/storage/schema/transactions.sql b/synapse/storage/schema/transactions.sql
new file mode 100644
index 0000000000..4b1a2368f6
--- /dev/null
+++ b/synapse/storage/schema/transactions.sql
@@ -0,0 +1,61 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- Stores what transaction ids we have received and what our response was
+CREATE TABLE IF NOT EXISTS received_transactions(
+ transaction_id TEXT,
+ origin TEXT,
+ ts INTEGER,
+ response_code INTEGER,
+ response_json TEXT,
+ has_been_referenced BOOL default 0, -- Whether thishas been referenced by a prev_tx
+ CONSTRAINT uniquesss UNIQUE (transaction_id, origin) ON CONFLICT REPLACE
+);
+
+CREATE UNIQUE INDEX IF NOT EXISTS transactions_txid ON received_transactions(transaction_id, origin);
+CREATE INDEX IF NOT EXISTS transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0;
+
+
+-- Stores what transactions we've sent, what their response was (if we got one) and whether we have
+-- since referenced the transaction in another outgoing transaction
+CREATE TABLE IF NOT EXISTS sent_transactions(
+ id INTEGER PRIMARY KEY AUTOINCREMENT, -- This is used to apply insertion ordering
+ transaction_id TEXT,
+ destination TEXT,
+ response_code INTEGER DEFAULT 0,
+ response_json TEXT,
+ ts INTEGER
+);
+
+CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination);
+CREATE INDEX IF NOT EXISTS sent_transaction_dest_referenced ON sent_transactions(
+ destination
+);
+-- So that we can do an efficient look up of all transactions that have yet to be successfully
+-- sent.
+CREATE INDEX IF NOT EXISTS sent_transaction_sent ON sent_transactions(response_code);
+
+
+-- For sent transactions only.
+CREATE TABLE IF NOT EXISTS transaction_id_to_pdu(
+ transaction_id INTEGER,
+ destination TEXT,
+ pdu_id TEXT,
+ pdu_origin TEXT
+);
+
+CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination);
+CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination);
+CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination);
+
diff --git a/synapse/storage/schema/users.sql b/synapse/storage/schema/users.sql
new file mode 100644
index 0000000000..46b60297cb
--- /dev/null
+++ b/synapse/storage/schema/users.sql
@@ -0,0 +1,31 @@
+/* Copyright 2014 matrix.org
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS users(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ name TEXT,
+ password_hash TEXT,
+ creation_ts INTEGER,
+ UNIQUE(name) ON CONFLICT ROLLBACK
+);
+
+CREATE TABLE IF NOT EXISTS access_tokens(
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_id INTEGER NOT NULL,
+ device_id TEXT,
+ token TEXT NOT NULL,
+ last_used INTEGER,
+ FOREIGN KEY(user_id) REFERENCES users(id),
+ UNIQUE(token) ON CONFLICT ROLLBACK
+);
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
new file mode 100644
index 0000000000..c3b1bfeb32
--- /dev/null
+++ b/synapse/storage/stream.py
@@ -0,0 +1,282 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+from .message import MessagesTable
+from .feedback import FeedbackTable
+from .roomdata import RoomDataTable
+from .roommember import RoomMemberTable
+
+import json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class StreamStore(SQLBaseStore):
+
+ def get_message_stream(self, user_id, from_key, to_key, room_id, limit=0,
+ with_feedback=False):
+ """Get all messages for this user between the given keys.
+
+ Args:
+ user_id (str): The user who is requesting messages.
+ from_key (int): The ID to start returning results from (exclusive).
+ to_key (int): The ID to stop returning results (exclusive).
+ room_id (str): Gets messages only for this room. Can be None, in
+ which case all room messages will be returned.
+ Returns:
+ A tuple of rows (list of namedtuples), new_id(int)
+ """
+ if with_feedback and room_id: # with fb MUST specify a room ID
+ return self._db_pool.runInteraction(
+ self._get_message_rows_with_feedback,
+ user_id, from_key, to_key, room_id, limit
+ )
+ else:
+ return self._db_pool.runInteraction(
+ self._get_message_rows,
+ user_id, from_key, to_key, room_id, limit
+ )
+
+ def _get_message_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
+ limit):
+ # work out which rooms this user is joined in on and join them with
+ # the room id on the messages table, bounded by the specified pkeys
+
+ # get all messages where the *current* membership state is 'join' for
+ # this user in that room.
+ query = ("SELECT messages.* FROM messages WHERE ? IN"
+ + " (SELECT membership from room_memberships WHERE user_id=?"
+ + " AND room_id = messages.room_id ORDER BY id DESC LIMIT 1)")
+ query_args = ["join", user_id]
+
+ if room_id:
+ query += " AND messages.room_id=?"
+ query_args.append(room_id)
+
+ (query, query_args) = self._append_stream_operations(
+ "messages", query, query_args, from_pkey, to_pkey, limit=limit
+ )
+
+ logger.debug("[SQL] %s : %s", query, query_args)
+ cursor = txn.execute(query, query_args)
+ return self._as_events(cursor, MessagesTable, from_pkey)
+
+ def _get_message_rows_with_feedback(self, txn, user_id, from_pkey, to_pkey,
+ room_id, limit):
+ # this col represents the compressed feedback JSON as per spec
+ compressed_feedback_col = (
+ "'[' || group_concat('{\"sender_id\":\"' || f.fb_sender_id"
+ + " || '\",\"feedback_type\":\"' || f.feedback_type"
+ + " || '\",\"content\":' || f.content || '}') || ']'"
+ )
+
+ global_msg_id_join = ("f.room_id = messages.room_id"
+ + " and f.msg_id = messages.msg_id"
+ + " and messages.user_id = f.msg_sender_id")
+
+ select_query = (
+ "SELECT messages.*, f.content AS fb_content, f.fb_sender_id"
+ + ", " + compressed_feedback_col + " AS compressed_fb"
+ + " FROM messages LEFT JOIN feedback f ON " + global_msg_id_join)
+
+ current_membership_sub_query = (
+ "(SELECT membership from room_memberships rm"
+ + " WHERE user_id=? AND room_id = rm.room_id"
+ + " ORDER BY id DESC LIMIT 1)")
+
+ where = (" WHERE ? IN " + current_membership_sub_query
+ + " AND messages.room_id=?")
+
+ query = select_query + where
+ query_args = ["join", user_id, room_id]
+
+ (query, query_args) = self._append_stream_operations(
+ "messages", query, query_args, from_pkey, to_pkey,
+ limit=limit, group_by=" GROUP BY messages.id "
+ )
+
+ logger.debug("[SQL] %s : %s", query, query_args)
+ cursor = txn.execute(query, query_args)
+
+ # convert the result set into events
+ entries = self.cursor_to_dict(cursor)
+ events = []
+ for entry in entries:
+ # TODO we should spec the cursor > event mapping somewhere else.
+ event = {}
+ straight_mappings = ["msg_id", "user_id", "room_id"]
+ for key in straight_mappings:
+ event[key] = entry[key]
+ event["content"] = json.loads(entry["content"])
+ if entry["compressed_fb"]:
+ event["feedback"] = json.loads(entry["compressed_fb"])
+ events.append(event)
+
+ latest_pkey = from_pkey if len(entries) == 0 else entries[-1]["id"]
+
+ return (events, latest_pkey)
+
+ def get_room_member_stream(self, user_id, from_key, to_key):
+ """Get all room membership events for this user between the given keys.
+
+ Args:
+ user_id (str): The user who is requesting membership events.
+ from_key (int): The ID to start returning results from (exclusive).
+ to_key (int): The ID to stop returning results (exclusive).
+ Returns:
+ A tuple of rows (list of namedtuples), new_id(int)
+ """
+ return self._db_pool.runInteraction(
+ self._get_room_member_rows, user_id, from_key, to_key
+ )
+
+ def _get_room_member_rows(self, txn, user_id, from_pkey, to_pkey):
+ # get all room membership events for rooms which the user is
+ # *currently* joined in on, or all invite events for this user.
+ current_membership_sub_query = (
+ "(SELECT membership FROM room_memberships"
+ + " WHERE user_id=? AND room_id = rm.room_id"
+ + " ORDER BY id DESC LIMIT 1)")
+
+ query = ("SELECT rm.* FROM room_memberships rm "
+ # all membership events for rooms you've currently joined.
+ + " WHERE (? IN " + current_membership_sub_query
+ # all invite membership events for this user
+ + " OR rm.membership=? AND user_id=?)"
+ + " AND rm.id > ?")
+ query_args = ["join", user_id, "invite", user_id, from_pkey]
+
+ if to_pkey != -1:
+ query += " AND rm.id < ?"
+ query_args.append(to_pkey)
+
+ cursor = txn.execute(query, query_args)
+ return self._as_events(cursor, RoomMemberTable, from_pkey)
+
+ def get_feedback_stream(self, user_id, from_key, to_key, room_id, limit=0):
+ return self._db_pool.runInteraction(
+ self._get_feedback_rows,
+ user_id, from_key, to_key, room_id, limit
+ )
+
+ def _get_feedback_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
+ limit):
+ # work out which rooms this user is joined in on and join them with
+ # the room id on the feedback table, bounded by the specified pkeys
+
+ # get all messages where the *current* membership state is 'join' for
+ # this user in that room.
+ query = (
+ "SELECT feedback.* FROM feedback WHERE ? IN "
+ + "(SELECT membership from room_memberships WHERE user_id=?"
+ + " AND room_id = feedback.room_id ORDER BY id DESC LIMIT 1)")
+ query_args = ["join", user_id]
+
+ if room_id:
+ query += " AND feedback.room_id=?"
+ query_args.append(room_id)
+
+ (query, query_args) = self._append_stream_operations(
+ "feedback", query, query_args, from_pkey, to_pkey, limit=limit
+ )
+
+ logger.debug("[SQL] %s : %s", query, query_args)
+ cursor = txn.execute(query, query_args)
+ return self._as_events(cursor, FeedbackTable, from_pkey)
+
+ def get_room_data_stream(self, user_id, from_key, to_key, room_id,
+ limit=0):
+ return self._db_pool.runInteraction(
+ self._get_room_data_rows,
+ user_id, from_key, to_key, room_id, limit
+ )
+
+ def _get_room_data_rows(self, txn, user_id, from_pkey, to_pkey, room_id,
+ limit):
+ # work out which rooms this user is joined in on and join them with
+ # the room id on the feedback table, bounded by the specified pkeys
+
+ # get all messages where the *current* membership state is 'join' for
+ # this user in that room.
+ query = (
+ "SELECT room_data.* FROM room_data WHERE ? IN "
+ + "(SELECT membership from room_memberships WHERE user_id=?"
+ + " AND room_id = room_data.room_id ORDER BY id DESC LIMIT 1)")
+ query_args = ["join", user_id]
+
+ if room_id:
+ query += " AND room_data.room_id=?"
+ query_args.append(room_id)
+
+ (query, query_args) = self._append_stream_operations(
+ "room_data", query, query_args, from_pkey, to_pkey, limit=limit
+ )
+
+ logger.debug("[SQL] %s : %s", query, query_args)
+ cursor = txn.execute(query, query_args)
+ return self._as_events(cursor, RoomDataTable, from_pkey)
+
+ def _append_stream_operations(self, table_name, query, query_args,
+ from_pkey, to_pkey, limit=None,
+ group_by=""):
+ LATEST_ROW = -1
+ order_by = ""
+ if to_pkey > from_pkey:
+ if from_pkey != LATEST_ROW:
+ # e.g. from=5 to=9 >> from 5 to 9 >> id>5 AND id<9
+ query += (" AND %s.id > ? AND %s.id < ?" %
+ (table_name, table_name))
+ query_args.append(from_pkey)
+ query_args.append(to_pkey)
+ else:
+ # e.g. from=-1 to=5 >> from now to 5 >> id>5 ORDER BY id DESC
+ query += " AND %s.id > ? " % table_name
+ order_by = "ORDER BY id DESC"
+ query_args.append(to_pkey)
+ elif from_pkey > to_pkey:
+ if to_pkey != LATEST_ROW:
+ # from=9 to=5 >> from 9 to 5 >> id>5 AND id<9 ORDER BY id DESC
+ query += (" AND %s.id > ? AND %s.id < ? " %
+ (table_name, table_name))
+ order_by = "ORDER BY id DESC"
+ query_args.append(to_pkey)
+ query_args.append(from_pkey)
+ else:
+ # from=5 to=-1 >> from 5 to now >> id>5
+ query += " AND %s.id > ?" % table_name
+ query_args.append(from_pkey)
+
+ query += group_by + order_by
+
+ if limit and limit > 0:
+ query += " LIMIT ?"
+ query_args.append(str(limit))
+
+ return (query, query_args)
+
+ def _as_events(self, cursor, table, from_pkey):
+ data_entries = table.decode_results(cursor)
+ last_pkey = from_pkey
+ if data_entries:
+ last_pkey = data_entries[-1].id
+
+ events = [
+ entry.as_event(self.event_factory).get_dict()
+ for entry in data_entries
+ ]
+
+ return (events, last_pkey)
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
new file mode 100644
index 0000000000..aa41e2ad7f
--- /dev/null
+++ b/synapse/storage/transactions.py
@@ -0,0 +1,287 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 matrix.org
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import SQLBaseStore, Table
+from .pdu import PdusTable
+
+from collections import namedtuple
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class TransactionStore(SQLBaseStore):
+ """A collection of queries for handling PDUs.
+ """
+
+ def get_received_txn_response(self, transaction_id, origin):
+ """For an incoming transaction from a given origin, check if we have
+ already responded to it. If so, return the response code and response
+ body (as a dict).
+
+ Args:
+ transaction_id (str)
+ origin(str)
+
+ Returns:
+ tuple: None if we have not previously responded to
+ this transaction or a 2-tuple of (int, dict)
+ """
+
+ return self._db_pool.runInteraction(
+ self._get_received_txn_response, transaction_id, origin
+ )
+
+ def _get_received_txn_response(self, txn, transaction_id, origin):
+ where_clause = "transaction_id = ? AND origin = ?"
+ query = ReceivedTransactionsTable.select_statement(where_clause)
+
+ txn.execute(query, (transaction_id, origin))
+
+ results = ReceivedTransactionsTable.decode_results(txn.fetchall())
+
+ if results and results[0].response_code:
+ return (results[0].response_code, results[0].response_json)
+ else:
+ return None
+
+ def set_received_txn_response(self, transaction_id, origin, code,
+ response_dict):
+ """Persist the response we returened for an incoming transaction, and
+ should return for subsequent transactions with the same transaction_id
+ and origin.
+
+ Args:
+ txn
+ transaction_id (str)
+ origin (str)
+ code (int)
+ response_json (str)
+ """
+
+ return self._db_pool.runInteraction(
+ self._set_received_txn_response,
+ transaction_id, origin, code, response_dict
+ )
+
+ def _set_received_txn_response(self, txn, transaction_id, origin, code,
+ response_json):
+ query = (
+ "UPDATE %s "
+ "SET response_code = ?, response_json = ? "
+ "WHERE transaction_id = ? AND origin = ?"
+ ) % ReceivedTransactionsTable.table_name
+
+ txn.execute(query, (code, response_json, transaction_id, origin))
+
+ def prep_send_transaction(self, transaction_id, destination, ts, pdu_list):
+ """Persists an outgoing transaction and calculates the values for the
+ previous transaction id list.
+
+ This should be called before sending the transaction so that it has the
+ correct value for the `prev_ids` key.
+
+ Args:
+ transaction_id (str)
+ destination (str)
+ ts (int)
+ pdu_list (list)
+
+ Returns:
+ list: A list of previous transaction ids.
+ """
+
+ return self._db_pool.runInteraction(
+ self._prep_send_transaction,
+ transaction_id, destination, ts, pdu_list
+ )
+
+ def _prep_send_transaction(self, txn, transaction_id, destination, ts,
+ pdu_list):
+
+ # First we find out what the prev_txs should be.
+ # Since we know that we are only sending one transaction at a time,
+ # we can simply take the last one.
+ query = "%s ORDER BY id DESC LIMIT 1" % (
+ SentTransactions.select_statement("destination = ?"),
+ )
+
+ results = txn.execute(query, (destination,))
+ results = SentTransactions.decode_results(results)
+
+ prev_txns = [r.transaction_id for r in results]
+
+ # Actually add the new transaction to the sent_transactions table.
+
+ query = SentTransactions.insert_statement()
+ txn.execute(query, SentTransactions.EntryType(
+ None,
+ transaction_id=transaction_id,
+ destination=destination,
+ ts=ts,
+ response_code=0,
+ response_json=None
+ ))
+
+ # Update the tx id -> pdu id mapping
+
+ values = [
+ (transaction_id, destination, pdu[0], pdu[1])
+ for pdu in pdu_list
+ ]
+
+ logger.debug("Inserting: %s", repr(values))
+
+ query = TransactionsToPduTable.insert_statement()
+ txn.executemany(query, values)
+
+ return prev_txns
+
+ def delivered_txn(self, transaction_id, destination, code, response_dict):
+ """Persists the response for an outgoing transaction.
+
+ Args:
+ transaction_id (str)
+ destination (str)
+ code (int)
+ response_json (str)
+ """
+ return self._db_pool.runInteraction(
+ self._delivered_txn,
+ transaction_id, destination, code, response_dict
+ )
+
+ def _delivered_txn(cls, txn, transaction_id, destination,
+ code, response_json):
+ query = (
+ "UPDATE %s "
+ "SET response_code = ?, response_json = ? "
+ "WHERE transaction_id = ? AND destination = ?"
+ ) % SentTransactions.table_name
+
+ txn.execute(query, (code, response_json, transaction_id, destination))
+
+ def get_transactions_after(self, transaction_id, destination):
+ """Get all transactions after a given local transaction_id.
+
+ Args:
+ transaction_id (str)
+ destination (str)
+
+ Returns:
+ list: A list of `ReceivedTransactionsTable.EntryType`
+ """
+ return self._db_pool.runInteraction(
+ self._get_transactions_after, transaction_id, destination
+ )
+
+ def _get_transactions_after(cls, txn, transaction_id, destination):
+ where = (
+ "destination = ? AND id > (select id FROM %s WHERE "
+ "transaction_id = ? AND destination = ?)"
+ ) % (
+ SentTransactions.table_name
+ )
+ query = SentTransactions.select_statement(where)
+
+ txn.execute(query, (destination, transaction_id, destination))
+
+ return ReceivedTransactionsTable.decode_results(txn.fetchall())
+
+ def get_pdus_after_transaction(self, transaction_id, destination):
+ """For a given local transaction_id that we sent to a given destination
+ home server, return a list of PDUs that were sent to that destination
+ after it.
+
+ Args:
+ txn
+ transaction_id (str)
+ destination (str)
+
+ Returns
+ list: A list of PduTuple
+ """
+ return self._db_pool.runInteraction(
+ self._get_pdus_after_transaction,
+ transaction_id, destination
+ )
+
+ def _get_pdus_after_transaction(self, txn, transaction_id, destination):
+
+ # Query that first get's all transaction_ids with an id greater than
+ # the one given from the `sent_transactions` table. Then JOIN on this
+ # from the `tx->pdu` table to get a list of (pdu_id, origin) that
+ # specify the pdus that were sent in those transactions.
+ query = (
+ "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp "
+ "INNER JOIN %(sent_tx)s as st "
+ "ON tp.transaction_id = st.transaction_id "
+ "AND tp.destination = st.destination "
+ "WHERE st.id > ("
+ "SELECT id FROM %(sent_tx)s "
+ "WHERE transaction_id = ? AND destination = ?"
+ ) % {
+ "tx_pdu": TransactionsToPduTable.table_name,
+ "sent_tx": SentTransactions.table_name,
+ }
+
+ txn.execute(query, (transaction_id, destination))
+
+ pdus = PdusTable.decode_results(txn.fetchall())
+
+ return self._get_pdu_tuples(txn, pdus)
+
+
+class ReceivedTransactionsTable(Table):
+ table_name = "received_transactions"
+
+ fields = [
+ "transaction_id",
+ "origin",
+ "ts",
+ "response_code",
+ "response_json",
+ "has_been_referenced",
+ ]
+
+ EntryType = namedtuple("ReceivedTransactionsEntry", fields)
+
+
+class SentTransactions(Table):
+ table_name = "sent_transactions"
+
+ fields = [
+ "id",
+ "transaction_id",
+ "destination",
+ "ts",
+ "response_code",
+ "response_json",
+ ]
+
+ EntryType = namedtuple("SentTransactionsEntry", fields)
+
+
+class TransactionsToPduTable(Table):
+ table_name = "transaction_id_to_pdu"
+
+ fields = [
+ "transaction_id",
+ "destination",
+ "pdu_id",
+ "pdu_origin",
+ ]
+
+ EntryType = namedtuple("TransactionsToPduEntry", fields)
|