diff options
author | matrix.org <matrix@matrix.org> | 2014-08-12 15:10:52 +0100 |
---|---|---|
committer | matrix.org <matrix@matrix.org> | 2014-08-12 15:10:52 +0100 |
commit | 4f475c7697722e946e39e42f38f3dd03a95d8765 (patch) | |
tree | 076d96d3809fb836c7245fd9f7960e7b75888a77 /synapse/storage | |
download | synapse-4f475c7697722e946e39e42f38f3dd03a95d8765.tar.xz |
Reference Matrix Home Server
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/__init__.py | 117 | ||||
-rw-r--r-- | synapse/storage/_base.py | 405 | ||||
-rw-r--r-- | synapse/storage/directory.py | 93 | ||||
-rw-r--r-- | synapse/storage/feedback.py | 74 | ||||
-rw-r--r-- | synapse/storage/message.py | 80 | ||||
-rw-r--r-- | synapse/storage/pdu.py | 993 | ||||
-rw-r--r-- | synapse/storage/presence.py | 103 | ||||
-rw-r--r-- | synapse/storage/profile.py | 51 | ||||
-rw-r--r-- | synapse/storage/registration.py | 113 | ||||
-rw-r--r-- | synapse/storage/room.py | 129 | ||||
-rw-r--r-- | synapse/storage/roomdata.py | 84 | ||||
-rw-r--r-- | synapse/storage/roommember.py | 171 | ||||
-rw-r--r-- | synapse/storage/schema/edge_pdus.sql | 31 | ||||
-rw-r--r-- | synapse/storage/schema/im.sql | 54 | ||||
-rw-r--r-- | synapse/storage/schema/pdu.sql | 106 | ||||
-rw-r--r-- | synapse/storage/schema/presence.sql | 37 | ||||
-rw-r--r-- | synapse/storage/schema/profiles.sql | 20 | ||||
-rw-r--r-- | synapse/storage/schema/room_aliases.sql | 12 | ||||
-rw-r--r-- | synapse/storage/schema/transactions.sql | 61 | ||||
-rw-r--r-- | synapse/storage/schema/users.sql | 31 | ||||
-rw-r--r-- | synapse/storage/stream.py | 282 | ||||
-rw-r--r-- | synapse/storage/transactions.py | 287 |
22 files changed, 3334 insertions, 0 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py new file mode 100644 index 0000000000..ec93f9f8a7 --- /dev/null +++ b/synapse/storage/__init__.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.api.events.room import ( + RoomMemberEvent, MessageEvent, RoomTopicEvent, FeedbackEvent, + RoomConfigEvent +) + +from .directory import DirectoryStore +from .feedback import FeedbackStore +from .message import MessageStore +from .presence import PresenceStore +from .profile import ProfileStore +from .registration import RegistrationStore +from .room import RoomStore +from .roommember import RoomMemberStore +from .roomdata import RoomDataStore +from .stream import StreamStore +from .pdu import StatePduStore, PduStore +from .transactions import TransactionStore + +import json +import os + + +class DataStore(RoomDataStore, RoomMemberStore, MessageStore, RoomStore, + RegistrationStore, StreamStore, ProfileStore, FeedbackStore, + PresenceStore, PduStore, StatePduStore, TransactionStore, + DirectoryStore): + + def __init__(self, hs): + super(DataStore, self).__init__(hs) + self.event_factory = hs.get_event_factory() + self.hs = hs + + def persist_event(self, event): + if event.type == MessageEvent.TYPE: + return self.store_message( + user_id=event.user_id, + room_id=event.room_id, + msg_id=event.msg_id, + content=json.dumps(event.content) + ) + elif event.type == RoomMemberEvent.TYPE: + return self.store_room_member( + user_id=event.target_user_id, + sender=event.user_id, + room_id=event.room_id, + content=event.content, + membership=event.content["membership"] + ) + elif event.type == FeedbackEvent.TYPE: + return self.store_feedback( + room_id=event.room_id, + msg_id=event.msg_id, + msg_sender_id=event.msg_sender_id, + fb_sender_id=event.user_id, + fb_type=event.feedback_type, + content=json.dumps(event.content) + ) + elif event.type == RoomTopicEvent.TYPE: + return self.store_room_data( + room_id=event.room_id, + etype=event.type, + state_key=event.state_key, + content=json.dumps(event.content) + ) + elif event.type == RoomConfigEvent.TYPE: + if "visibility" in event.content: + visibility = event.content["visibility"] + return self.store_room_config( + room_id=event.room_id, + visibility=visibility + ) + + else: + raise NotImplementedError( + "Don't know how to persist type=%s" % event.type + ) + + +def schema_path(schema): + """ Get a filesystem path for the named database schema + + Args: + schema: Name of the database schema. + Returns: + A filesystem path pointing at a ".sql" file. + + """ + dir_path = os.path.dirname(__file__) + schemaPath = os.path.join(dir_path, "schema", schema + ".sql") + return schemaPath + + +def read_schema(schema): + """ Read the named database schema. + + Args: + schema: Name of the datbase schema. + Returns: + A string containing the database schema. + """ + with open(schema_path(schema)) as schema_file: + return schema_file.read() diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py new file mode 100644 index 0000000000..4d98a6fd0d --- /dev/null +++ b/synapse/storage/_base.py @@ -0,0 +1,405 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from twisted.internet import defer + +from synapse.api.errors import StoreError + +import collections + +logger = logging.getLogger(__name__) + + +class SQLBaseStore(object): + + def __init__(self, hs): + self._db_pool = hs.get_db_pool() + + def cursor_to_dict(self, cursor): + """Converts a SQL cursor into an list of dicts. + + Args: + cursor : The DBAPI cursor which has executed a query. + Returns: + A list of dicts where the key is the column header. + """ + col_headers = list(column[0] for column in cursor.description) + results = list( + dict(zip(col_headers, row)) for row in cursor.fetchall() + ) + return results + + def _execute(self, decoder, query, *args): + """Runs a single query for a result set. + + Args: + decoder - The function which can resolve the cursor results to + something meaningful. + query - The query string to execute + *args - Query args. + Returns: + The result of decoder(results) + """ + logger.debug( + "[SQL] %s Args=%s Func=%s", query, args, decoder.__name__ + ) + + def interaction(txn): + cursor = txn.execute(query, args) + return decoder(cursor) + return self._db_pool.runInteraction(interaction) + + # "Simple" SQL API methods that operate on a single table with no JOINs, + # no complex WHERE clauses, just a dict of values for columns. + + def _simple_insert(self, table, values): + """Executes an INSERT query on the named table. + + Args: + table : string giving the table name + values : dict of new column names and values for them + """ + sql = "INSERT INTO %s (%s) VALUES(%s)" % ( + table, + ", ".join(k for k in values), + ", ".join("?" for k in values) + ) + + def func(txn): + txn.execute(sql, values.values()) + return txn.lastrowid + return self._db_pool.runInteraction(func) + + def _simple_select_one(self, table, keyvalues, retcols, + allow_none=False): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning a single column from it. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcols : list of strings giving the names of the columns to return + + allow_none : If true, return None instead of failing if the SELECT + statement returns no rows + """ + return self._simple_selectupdate_one( + table, keyvalues, retcols=retcols, allow_none=allow_none + ) + + @defer.inlineCallbacks + def _simple_select_one_onecol(self, table, keyvalues, retcol, + allow_none=False): + """Executes a SELECT query on the named table, which is expected to + return a single row, returning a single column from it." + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + retcol : string giving the name of the column to return + """ + ret = yield self._simple_select_one( + table=table, + keyvalues=keyvalues, + retcols=[retcol], + allow_none=allow_none + ) + + if ret: + defer.returnValue(ret[retcol]) + else: + defer.returnValue(None) + + @defer.inlineCallbacks + def _simple_select_onecol(self, table, keyvalues, retcol): + """Executes a SELECT query on the named table, which returns a list + comprising of the values of the named column from the selected rows. + + Args: + table (str): table name + keyvalues (dict): column names and values to select the rows with + retcol (str): column whos value we wish to retrieve. + + Returns: + Deferred: Results in a list + """ + sql = "SELECT %(retcol)s FROM %(table)s WHERE %(where)s" % { + "retcol": retcol, + "table": table, + "where": " AND ".join("%s = ?" % k for k in keyvalues.keys()), + } + + def func(txn): + txn.execute(sql, keyvalues.values()) + return txn.fetchall() + + res = yield self._db_pool.runInteraction(func) + + defer.returnValue([r[0] for r in res]) + + def _simple_select_list(self, table, keyvalues, retcols): + """Executes a SELECT query on the named table, which may return zero or + more rows, returning the result as a list of dicts. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the rows with + retcols : list of strings giving the names of the columns to return + """ + sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k) for k in keyvalues) + ) + + def func(txn): + txn.execute(sql, keyvalues.values()) + return self.cursor_to_dict(txn) + + return self._db_pool.runInteraction(func) + + def _simple_update_one(self, table, keyvalues, updatevalues, + retcols=None): + """Executes an UPDATE query on the named table, setting new values for + columns in a row matching the key values. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + updatevalues : dict giving column names and values to update + retcols : optional list of column names to return + + If present, retcols gives a list of column names on which to perform + a SELECT statement *before* performing the UPDATE statement. The values + of these will be returned in a dict. + + These are performed within the same transaction, allowing an atomic + get-and-set. This can be used to implement compare-and-set by putting + the update column in the 'keyvalues' dict as well. + """ + return self._simple_selectupdate_one(table, keyvalues, updatevalues, + retcols=retcols) + + def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None, + retcols=None, allow_none=False): + """ Combined SELECT then UPDATE.""" + if retcols: + select_sql = "SELECT %s FROM %s WHERE %s" % ( + ", ".join(retcols), + table, + " AND ".join("%s = ?" % (k) for k in keyvalues) + ) + + if updatevalues: + update_sql = "UPDATE %s SET %s WHERE %s" % ( + table, + ", ".join("%s = ?" % (k) for k in updatevalues), + " AND ".join("%s = ?" % (k) for k in keyvalues) + ) + + def func(txn): + ret = None + if retcols: + txn.execute(select_sql, keyvalues.values()) + + row = txn.fetchone() + if not row: + if allow_none: + return None + raise StoreError(404, "No row found") + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched") + + ret = dict(zip(retcols, row)) + + if updatevalues: + txn.execute( + update_sql, + updatevalues.values() + keyvalues.values() + ) + + if txn.rowcount == 0: + raise StoreError(404, "No row found") + if txn.rowcount > 1: + raise StoreError(500, "More than one row matched") + + return ret + return self._db_pool.runInteraction(func) + + def _simple_delete_one(self, table, keyvalues): + """Executes a DELETE query on the named table, expecting to delete a + single row. + + Args: + table : string giving the table name + keyvalues : dict of column names and values to select the row with + """ + sql = "DELETE FROM %s WHERE %s" % ( + table, + " AND ".join("%s = ?" % (k) for k in keyvalues) + ) + + def func(txn): + txn.execute(sql, keyvalues.values()) + if txn.rowcount == 0: + raise StoreError(404, "No row found") + if txn.rowcount > 1: + raise StoreError(500, "more than one row matched") + return self._db_pool.runInteraction(func) + + def _simple_max_id(self, table): + """Executes a SELECT query on the named table, expecting to return the + max value for the column "id". + + Args: + table : string giving the table name + """ + sql = "SELECT MAX(id) AS id FROM %s" % table + + def func(txn): + txn.execute(sql) + max_id = self.cursor_to_dict(txn)[0]["id"] + if max_id is None: + return 0 + return max_id + + return self._db_pool.runInteraction(func) + + +class Table(object): + """ A base class used to store information about a particular table. + """ + + table_name = None + """ str: The name of the table """ + + fields = None + """ list: The field names """ + + EntryType = None + """ Type: A tuple type used to decode the results """ + + _select_where_clause = "SELECT %s FROM %s WHERE %s" + _select_clause = "SELECT %s FROM %s" + _insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)" + + @classmethod + def select_statement(cls, where_clause=None): + """ + Args: + where_clause (str): The WHERE clause to use. + + Returns: + str: An SQL statement to select rows from the table with the given + WHERE clause. + """ + if where_clause: + return cls._select_where_clause % ( + ", ".join(cls.fields), + cls.table_name, + where_clause + ) + else: + return cls._select_clause % ( + ", ".join(cls.fields), + cls.table_name, + ) + + @classmethod + def insert_statement(cls): + return cls._insert_clause % ( + cls.table_name, + ", ".join(cls.fields), + ", ".join(["?"] * len(cls.fields)), + ) + + @classmethod + def decode_single_result(cls, results): + """ Given an iterable of tuples, return a single instance of + `EntryType` or None if the iterable is empty + Args: + results (list): The results list to convert to `EntryType` + Returns: + EntryType: An instance of `EntryType` + """ + results = list(results) + if results: + return cls.EntryType(*results[0]) + else: + return None + + @classmethod + def decode_results(cls, results): + """ Given an iterable of tuples, return a list of `EntryType` + Args: + results (list): The results list to convert to `EntryType` + + Returns: + list: A list of `EntryType` + """ + return [cls.EntryType(*row) for row in results] + + @classmethod + def get_fields_string(cls, prefix=None): + if prefix: + to_join = ("%s.%s" % (prefix, f) for f in cls.fields) + else: + to_join = cls.fields + + return ", ".join(to_join) + + +class JoinHelper(object): + """ Used to help do joins on tables by looking at the tables' fields and + creating a list of unique fields to use with SELECTs and a namedtuple + to dump the results into. + + Attributes: + taples (list): List of `Table` classes + EntryType (type) + """ + + def __init__(self, *tables): + self.tables = tables + + res = [] + for table in self.tables: + res += [f for f in table.fields if f not in res] + + self.EntryType = collections.namedtuple("JoinHelperEntry", res) + + def get_fields(self, **prefixes): + """Get a string representing a list of fields for use in SELECT + statements with the given prefixes applied to each. + + For example:: + + JoinHelper(PdusTable, StateTable).get_fields( + PdusTable="pdus", + StateTable="state" + ) + """ + res = [] + for field in self.EntryType._fields: + for table in self.tables: + if field in table.fields: + res.append("%s.%s" % (prefixes[table.__name__], field)) + break + + return ", ".join(res) + + def decode_results(self, rows): + return [self.EntryType(*row) for row in rows] diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py new file mode 100644 index 0000000000..71fa9d9c9c --- /dev/null +++ b/synapse/storage/directory.py @@ -0,0 +1,93 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import SQLBaseStore +from twisted.internet import defer + +from collections import namedtuple + + +RoomAliasMapping = namedtuple( + "RoomAliasMapping", + ("room_id", "room_alias", "servers",) +) + + +class DirectoryStore(SQLBaseStore): + + @defer.inlineCallbacks + def get_association_from_room_alias(self, room_alias): + """ Get's the room_id and server list for a given room_alias + + Args: + room_alias (RoomAlias) + + Returns: + Deferred: results in namedtuple with keys "room_id" and + "servers" or None if no association can be found + """ + room_id = yield self._simple_select_one_onecol( + "room_aliases", + {"room_alias": room_alias.to_string()}, + "room_id", + allow_none=True, + ) + + if not room_id: + defer.returnValue(None) + return + + servers = yield self._simple_select_onecol( + "room_alias_servers", + {"room_alias": room_alias.to_string()}, + "server", + ) + + if not servers: + defer.returnValue(None) + return + + defer.returnValue( + RoomAliasMapping(room_id, room_alias.to_string(), servers) + ) + + @defer.inlineCallbacks + def create_room_alias_association(self, room_alias, room_id, servers): + """ Creates an associatin between a room alias and room_id/servers + + Args: + room_alias (RoomAlias) + room_id (str) + servers (list) + + Returns: + Deferred + """ + yield self._simple_insert( + "room_aliases", + { + "room_alias": room_alias.to_string(), + "room_id": room_id, + }, + ) + + for server in servers: + # TODO(erikj): Fix this to bulk insert + yield self._simple_insert( + "room_alias_servers", + { + "room_alias": room_alias.to_string(), + "server": server, + } + ) diff --git a/synapse/storage/feedback.py b/synapse/storage/feedback.py new file mode 100644 index 0000000000..2b421e3342 --- /dev/null +++ b/synapse/storage/feedback.py @@ -0,0 +1,74 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import SQLBaseStore, Table +from synapse.api.events.room import FeedbackEvent + +import collections +import json + + +class FeedbackStore(SQLBaseStore): + + def store_feedback(self, room_id, msg_id, msg_sender_id, + fb_sender_id, fb_type, content): + return self._simple_insert(FeedbackTable.table_name, dict( + room_id=room_id, + msg_id=msg_id, + msg_sender_id=msg_sender_id, + fb_sender_id=fb_sender_id, + fb_type=fb_type, + content=content, + )) + + def get_feedback(self, room_id=None, msg_id=None, msg_sender_id=None, + fb_sender_id=None, fb_type=None): + query = FeedbackTable.select_statement( + "msg_sender_id = ? AND room_id = ? AND msg_id = ? " + + "AND fb_sender_id = ? AND feedback_type = ? " + + "ORDER BY id DESC LIMIT 1") + return self._execute( + FeedbackTable.decode_single_result, + query, msg_sender_id, room_id, msg_id, fb_sender_id, fb_type, + ) + + def get_max_feedback_id(self): + return self._simple_max_id(FeedbackTable.table_name) + + +class FeedbackTable(Table): + table_name = "feedback" + + fields = [ + "id", + "content", + "feedback_type", + "fb_sender_id", + "msg_id", + "room_id", + "msg_sender_id" + ] + + class EntryType(collections.namedtuple("FeedbackEntry", fields)): + + def as_event(self, event_factory): + return event_factory.create_event( + etype=FeedbackEvent.TYPE, + room_id=self.room_id, + msg_id=self.msg_id, + msg_sender_id=self.msg_sender_id, + user_id=self.fb_sender_id, + feedback_type=self.feedback_type, + content=json.loads(self.content), + ) diff --git a/synapse/storage/message.py b/synapse/storage/message.py new file mode 100644 index 0000000000..4822fa709d --- /dev/null +++ b/synapse/storage/message.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import SQLBaseStore, Table +from synapse.api.events.room import MessageEvent + +import collections +import json + + +class MessageStore(SQLBaseStore): + + def get_message(self, user_id, room_id, msg_id): + """Get a message from the store. + + Args: + user_id (str): The ID of the user who sent the message. + room_id (str): The room the message was sent in. + msg_id (str): The unique ID for this user/room combo. + """ + query = MessagesTable.select_statement( + "user_id = ? AND room_id = ? AND msg_id = ? " + + "ORDER BY id DESC LIMIT 1") + return self._execute( + MessagesTable.decode_single_result, + query, user_id, room_id, msg_id, + ) + + def store_message(self, user_id, room_id, msg_id, content): + """Store a message in the store. + + Args: + user_id (str): The ID of the user who sent the message. + room_id (str): The room the message was sent in. + msg_id (str): The unique ID for this user/room combo. + content (str): The content of the message (JSON) + """ + return self._simple_insert(MessagesTable.table_name, dict( + user_id=user_id, + room_id=room_id, + msg_id=msg_id, + content=content, + )) + + def get_max_message_id(self): + return self._simple_max_id(MessagesTable.table_name) + + +class MessagesTable(Table): + table_name = "messages" + + fields = [ + "id", + "user_id", + "room_id", + "msg_id", + "content" + ] + + class EntryType(collections.namedtuple("MessageEntry", fields)): + + def as_event(self, event_factory): + return event_factory.create_event( + etype=MessageEvent.TYPE, + room_id=self.room_id, + user_id=self.user_id, + msg_id=self.msg_id, + content=json.loads(self.content), + ) diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py new file mode 100644 index 0000000000..a1cdde0a3b --- /dev/null +++ b/synapse/storage/pdu.py @@ -0,0 +1,993 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import SQLBaseStore, Table, JoinHelper + +from synapse.util.logutils import log_function + +from collections import namedtuple + +import logging + +logger = logging.getLogger(__name__) + + +class PduStore(SQLBaseStore): + """A collection of queries for handling PDUs. + """ + + def get_pdu(self, pdu_id, origin): + """Given a pdu_id and origin, get a PDU. + + Args: + txn + pdu_id (str) + origin (str) + + Returns: + PduTuple: If the pdu does not exist in the database, returns None + """ + + return self._db_pool.runInteraction( + self._get_pdu_tuple, pdu_id, origin + ) + + def _get_pdu_tuple(self, txn, pdu_id, origin): + res = self._get_pdu_tuples(txn, [(pdu_id, origin)]) + return res[0] if res else None + + def _get_pdu_tuples(self, txn, pdu_id_tuples): + results = [] + for pdu_id, origin in pdu_id_tuples: + txn.execute( + PduEdgesTable.select_statement("pdu_id = ? AND origin = ?"), + (pdu_id, origin) + ) + + edges = [ + (r.prev_pdu_id, r.prev_origin) + for r in PduEdgesTable.decode_results(txn.fetchall()) + ] + + query = ( + "SELECT %(fields)s FROM %(pdus)s as p " + "LEFT JOIN %(state)s as s " + "ON p.pdu_id = s.pdu_id AND p.origin = s.origin " + "WHERE p.pdu_id = ? AND p.origin = ? " + ) % { + "fields": _pdu_state_joiner.get_fields( + PdusTable="p", StatePdusTable="s"), + "pdus": PdusTable.table_name, + "state": StatePdusTable.table_name, + } + + txn.execute(query, (pdu_id, origin)) + + row = txn.fetchone() + if row: + results.append(PduTuple(PduEntry(*row), edges)) + + return results + + def get_current_state_for_context(self, context): + """Get a list of PDUs that represent the current state for a given + context + + Args: + context (str) + + Returns: + list: A list of PduTuples + """ + + return self._db_pool.runInteraction( + self._get_current_state_for_context, + context + ) + + def _get_current_state_for_context(self, txn, context): + query = ( + "SELECT pdu_id, origin FROM %s WHERE context = ?" + % CurrentStateTable.table_name + ) + + logger.debug("get_current_state %s, Args=%s", query, context) + txn.execute(query, (context,)) + + res = txn.fetchall() + + logger.debug("get_current_state %d results", len(res)) + + return self._get_pdu_tuples(txn, res) + + def persist_pdu(self, prev_pdus, **cols): + """Inserts a (non-state) PDU into the database. + + Args: + txn, + prev_pdus (list) + **cols: The columns to insert into the PdusTable. + """ + return self._db_pool.runInteraction( + self._persist_pdu, prev_pdus, cols + ) + + def _persist_pdu(self, txn, prev_pdus, cols): + entry = PdusTable.EntryType( + **{k: cols.get(k, None) for k in PdusTable.fields} + ) + + txn.execute(PdusTable.insert_statement(), entry) + + self._handle_prev_pdus( + txn, entry.outlier, entry.pdu_id, entry.origin, + prev_pdus, entry.context + ) + + def mark_pdu_as_processed(self, pdu_id, pdu_origin): + """Mark a received PDU as processed. + + Args: + txn + pdu_id (str) + pdu_origin (str) + """ + + return self._db_pool.runInteraction( + self._mark_as_processed, pdu_id, pdu_origin + ) + + def _mark_as_processed(self, txn, pdu_id, pdu_origin): + txn.execute("UPDATE %s SET have_processed = 1" % PdusTable.table_name) + + def get_all_pdus_from_context(self, context): + """Get a list of all PDUs for a given context.""" + return self._db_pool.runInteraction( + self._get_all_pdus_from_context, context, + ) + + def _get_all_pdus_from_context(self, txn, context): + query = ( + "SELECT pdu_id, origin FROM %s " + "WHERE context = ?" + ) % PdusTable.table_name + + txn.execute(query, (context,)) + + return self._get_pdu_tuples(txn, txn.fetchall()) + + def get_pagination(self, context, pdu_list, limit): + """Get a list of Pdus for a given topic that occured before (and + including) the pdus in pdu_list. Return a list of max size `limit`. + + Args: + txn + context (str) + pdu_list (list) + limit (int) + + Return: + list: A list of PduTuples + """ + return self._db_pool.runInteraction( + self._get_paginate, context, pdu_list, limit + ) + + def _get_paginate(self, txn, context, pdu_list, limit): + logger.debug( + "paginate: %s, %s, %s", + context, repr(pdu_list), limit + ) + + # We seed the pdu_results with the things from the pdu_list. + pdu_results = pdu_list + + front = pdu_list + + query = ( + "SELECT prev_pdu_id, prev_origin FROM %(edges_table)s " + "WHERE context = ? AND pdu_id = ? AND origin = ? " + "LIMIT ?" + ) % { + "edges_table": PduEdgesTable.table_name, + } + + # We iterate through all pdu_ids in `front` to select their previous + # pdus. These are dumped in `new_front`. We continue until we reach the + # limit *or* new_front is empty (i.e., we've run out of things to + # select + while front and len(pdu_results) < limit: + + new_front = [] + for pdu_id, origin in front: + logger.debug( + "_paginate_interaction: i=%s, o=%s", + pdu_id, origin + ) + + txn.execute( + query, + (context, pdu_id, origin, limit - len(pdu_results)) + ) + + for row in txn.fetchall(): + logger.debug( + "_paginate_interaction: got i=%s, o=%s", + *row + ) + new_front.append(row) + + front = new_front + pdu_results += new_front + + # We also want to update the `prev_pdus` attributes before returning. + return self._get_pdu_tuples(txn, pdu_results) + + def get_min_depth_for_context(self, context): + """Get the current minimum depth for a context + + Args: + txn + context (str) + """ + return self._db_pool.runInteraction( + self._get_min_depth_for_context, context + ) + + def _get_min_depth_for_context(self, txn, context): + return self._get_min_depth_interaction(txn, context) + + def _get_min_depth_interaction(self, txn, context): + txn.execute( + "SELECT min_depth FROM %s WHERE context = ?" + % ContextDepthTable.table_name, + (context,) + ) + + row = txn.fetchone() + + return row[0] if row else None + + def update_min_depth_for_context(self, context, depth): + """Update the minimum `depth` of the given context, which is the line + where we stop paginating backwards on. + + Args: + context (str) + depth (int) + """ + return self._db_pool.runInteraction( + self._update_min_depth_for_context, context, depth + ) + + def _update_min_depth_for_context(self, txn, context, depth): + min_depth = self._get_min_depth_interaction(txn, context) + + do_insert = depth < min_depth if min_depth else True + + if do_insert: + txn.execute( + "INSERT OR REPLACE INTO %s (context, min_depth) " + "VALUES (?,?)" % ContextDepthTable.table_name, + (context, depth) + ) + + def get_latest_pdus_in_context(self, context): + """Get's a list of the most current pdus for a given context. This is + used when we are sending a Pdu and need to fill out the `prev_pdus` + key + + Args: + txn + context + """ + return self._db_pool.runInteraction( + self._get_latest_pdus_in_context, context + ) + + def _get_latest_pdus_in_context(self, txn, context): + query = ( + "SELECT p.pdu_id, p.origin, p.depth FROM %(pdus)s as p " + "INNER JOIN %(forward)s as f ON p.pdu_id = f.pdu_id " + "AND f.origin = p.origin " + "WHERE f.context = ?" + ) % { + "pdus": PdusTable.table_name, + "forward": PduForwardExtremitiesTable.table_name, + } + + logger.debug("get_prev query: %s", query) + + txn.execute( + query, + (context, ) + ) + + results = txn.fetchall() + + return [(row[0], row[1], row[2]) for row in results] + + def get_oldest_pdus_in_context(self, context): + """Get a list of Pdus that we paginated beyond yet (and haven't seen). + This list is used when we want to paginate backwards and is the list we + send to the remote server. + + Args: + txn + context (str) + + Returns: + list: A list of PduIdTuple. + """ + return self._db_pool.runInteraction( + self._get_oldest_pdus_in_context, context + ) + + def _get_oldest_pdus_in_context(self, txn, context): + txn.execute( + "SELECT pdu_id, origin FROM %(back)s WHERE context = ?" + % {"back": PduBackwardExtremitiesTable.table_name, }, + (context,) + ) + return [PduIdTuple(i, o) for i, o in txn.fetchall()] + + def is_pdu_new(self, pdu_id, origin, context, depth): + """For a given Pdu, try and figure out if it's 'new', i.e., if it's + not something we got randomly from the past, for example when we + request the current state of the room that will probably return a bunch + of pdus from before we joined. + + Args: + txn + pdu_id (str) + origin (str) + context (str) + depth (int) + + Returns: + bool + """ + + return self._db_pool.runInteraction( + self._is_pdu_new, + pdu_id=pdu_id, + origin=origin, + context=context, + depth=depth + ) + + def _is_pdu_new(self, txn, pdu_id, origin, context, depth): + # If depth > min depth in back table, then we classify it as new. + # OR if there is nothing in the back table, then it kinda needs to + # be a new thing. + query = ( + "SELECT min(p.depth) FROM %(edges)s as e " + "INNER JOIN %(back)s as b " + "ON e.prev_pdu_id = b.pdu_id AND e.prev_origin = b.origin " + "INNER JOIN %(pdus)s as p " + "ON e.pdu_id = p.pdu_id AND p.origin = e.origin " + "WHERE p.context = ?" + ) % { + "pdus": PdusTable.table_name, + "edges": PduEdgesTable.table_name, + "back": PduBackwardExtremitiesTable.table_name, + } + + txn.execute(query, (context,)) + + min_depth, = txn.fetchone() + + if not min_depth or depth > int(min_depth): + logger.debug( + "is_new true: id=%s, o=%s, d=%s min_depth=%s", + pdu_id, origin, depth, min_depth + ) + return True + + # If this pdu is in the forwards table, then it also is a new one + query = ( + "SELECT * FROM %(forward)s WHERE pdu_id = ? AND origin = ?" + ) % { + "forward": PduForwardExtremitiesTable.table_name, + } + + txn.execute(query, (pdu_id, origin)) + + # Did we get anything? + if txn.fetchall(): + logger.debug( + "is_new true: id=%s, o=%s, d=%s was forward", + pdu_id, origin, depth + ) + return True + + logger.debug( + "is_new false: id=%s, o=%s, d=%s", + pdu_id, origin, depth + ) + + # FINE THEN. It's probably old. + return False + + @staticmethod + @log_function + def _handle_prev_pdus(txn, outlier, pdu_id, origin, prev_pdus, + context): + txn.executemany( + PduEdgesTable.insert_statement(), + [(pdu_id, origin, p[0], p[1], context) for p in prev_pdus] + ) + + # Update the extremities table if this is not an outlier. + if not outlier: + + # First, we delete the new one from the forwards extremities table. + query = ( + "DELETE FROM %s WHERE pdu_id = ? AND origin = ?" + % PduForwardExtremitiesTable.table_name + ) + txn.executemany(query, prev_pdus) + + # We only insert as a forward extremety the new pdu if there are no + # other pdus that reference it as a prev pdu + query = ( + "INSERT INTO %(table)s (pdu_id, origin, context) " + "SELECT ?, ?, ? WHERE NOT EXISTS (" + "SELECT 1 FROM %(pdu_edges)s WHERE " + "prev_pdu_id = ? AND prev_origin = ?" + ")" + ) % { + "table": PduForwardExtremitiesTable.table_name, + "pdu_edges": PduEdgesTable.table_name + } + + logger.debug("query: %s", query) + + txn.execute(query, (pdu_id, origin, context, pdu_id, origin)) + + # Insert all the prev_pdus as a backwards thing, they'll get + # deleted in a second if they're incorrect anyway. + txn.executemany( + PduBackwardExtremitiesTable.insert_statement(), + [(i, o, context) for i, o in prev_pdus] + ) + + # Also delete from the backwards extremities table all ones that + # reference pdus that we have already seen + query = ( + "DELETE FROM %(pdu_back)s WHERE EXISTS (" + "SELECT 1 FROM %(pdus)s AS pdus " + "WHERE " + "%(pdu_back)s.pdu_id = pdus.pdu_id " + "AND %(pdu_back)s.origin = pdus.origin " + "AND not pdus.outlier " + ")" + ) % { + "pdu_back": PduBackwardExtremitiesTable.table_name, + "pdus": PdusTable.table_name, + } + txn.execute(query) + + +class StatePduStore(SQLBaseStore): + """A collection of queries for handling state PDUs. + """ + + def persist_state(self, prev_pdus, **cols): + """Inserts a state PDU into the database + + Args: + txn, + prev_pdus (list) + **cols: The columns to insert into the PdusTable and StatePdusTable + """ + + return self._db_pool.runInteraction( + self._persist_state, prev_pdus, cols + ) + + def _persist_state(self, txn, prev_pdus, cols): + pdu_entry = PdusTable.EntryType( + **{k: cols.get(k, None) for k in PdusTable.fields} + ) + state_entry = StatePdusTable.EntryType( + **{k: cols.get(k, None) for k in StatePdusTable.fields} + ) + + logger.debug("Inserting pdu: %s", repr(pdu_entry)) + logger.debug("Inserting state: %s", repr(state_entry)) + + txn.execute(PdusTable.insert_statement(), pdu_entry) + txn.execute(StatePdusTable.insert_statement(), state_entry) + + self._handle_prev_pdus( + txn, + pdu_entry.outlier, pdu_entry.pdu_id, pdu_entry.origin, prev_pdus, + pdu_entry.context + ) + + def get_unresolved_state_tree(self, new_state_pdu): + return self._db_pool.runInteraction( + self._get_unresolved_state_tree, new_state_pdu + ) + + @log_function + def _get_unresolved_state_tree(self, txn, new_pdu): + current = self._get_current_interaction( + txn, + new_pdu.context, new_pdu.pdu_type, new_pdu.state_key + ) + + ReturnType = namedtuple( + "StateReturnType", ["new_branch", "current_branch"] + ) + return_value = ReturnType([new_pdu], []) + + if not current: + logger.debug("get_unresolved_state_tree No current state.") + return return_value + + return_value.current_branch.append(current) + + enum_branches = self._enumerate_state_branches( + txn, new_pdu, current + ) + + for branch, prev_state, state in enum_branches: + if state: + return_value[branch].append(state) + else: + break + + return return_value + + def update_current_state(self, pdu_id, origin, context, pdu_type, + state_key): + return self._db_pool.runInteraction( + self._update_current_state, + pdu_id, origin, context, pdu_type, state_key + ) + + def _update_current_state(self, txn, pdu_id, origin, context, pdu_type, + state_key): + query = ( + "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)" + ) % { + "curr": CurrentStateTable.table_name, + "fields": CurrentStateTable.get_fields_string(), + "qs": ", ".join(["?"] * len(CurrentStateTable.fields)) + } + + query_args = CurrentStateTable.EntryType( + pdu_id=pdu_id, + origin=origin, + context=context, + pdu_type=pdu_type, + state_key=state_key + ) + + txn.execute(query, query_args) + + def get_current_state(self, context, pdu_type, state_key): + """For a given context, pdu_type, state_key 3-tuple, return what is + currently considered the current state. + + Args: + txn + context (str) + pdu_type (str) + state_key (str) + + Returns: + PduEntry + """ + + return self._db_pool.runInteraction( + self._get_current_state, context, pdu_type, state_key + ) + + def _get_current_state(self, txn, context, pdu_type, state_key): + return self._get_current_interaction(txn, context, pdu_type, state_key) + + def _get_current_interaction(self, txn, context, pdu_type, state_key): + logger.debug( + "_get_current_interaction %s %s %s", + context, pdu_type, state_key + ) + + fields = _pdu_state_joiner.get_fields( + PdusTable="p", StatePdusTable="s") + + current_query = ( + "SELECT %(fields)s FROM %(state)s as s " + "INNER JOIN %(pdus)s as p " + "ON s.pdu_id = p.pdu_id AND s.origin = p.origin " + "INNER JOIN %(curr)s as c " + "ON s.pdu_id = c.pdu_id AND s.origin = c.origin " + "WHERE s.context = ? AND s.pdu_type = ? AND s.state_key = ? " + ) % { + "fields": fields, + "curr": CurrentStateTable.table_name, + "state": StatePdusTable.table_name, + "pdus": PdusTable.table_name, + } + + txn.execute( + current_query, + (context, pdu_type, state_key) + ) + + row = txn.fetchone() + + result = PduEntry(*row) if row else None + + if not result: + logger.debug("_get_current_interaction not found") + else: + logger.debug( + "_get_current_interaction found %s %s", + result.pdu_id, result.origin + ) + + return result + + def get_next_missing_pdu(self, new_pdu): + """When we get a new state pdu we need to check whether we need to do + any conflict resolution, if we do then we need to check if we need + to go back and request some more state pdus that we haven't seen yet. + + Args: + txn + new_pdu + + Returns: + PduIdTuple: A pdu that we are missing, or None if we have all the + pdus required to do the conflict resolution. + """ + return self._db_pool.runInteraction( + self._get_next_missing_pdu, new_pdu + ) + + def _get_next_missing_pdu(self, txn, new_pdu): + logger.debug( + "get_next_missing_pdu %s %s", + new_pdu.pdu_id, new_pdu.origin + ) + + current = self._get_current_interaction( + txn, + new_pdu.context, new_pdu.pdu_type, new_pdu.state_key + ) + + if (not current or not current.prev_state_id + or not current.prev_state_origin): + return None + + # Oh look, it's a straight clobber, so wooooo almost no-op. + if (new_pdu.prev_state_id == current.pdu_id + and new_pdu.prev_state_origin == current.origin): + return None + + enum_branches = self._enumerate_state_branches(txn, new_pdu, current) + for branch, prev_state, state in enum_branches: + if not state: + return PduIdTuple( + prev_state.prev_state_id, + prev_state.prev_state_origin + ) + + return None + + def handle_new_state(self, new_pdu): + """Actually perform conflict resolution on the new_pdu on the + assumption we have all the pdus required to perform it. + + Args: + new_pdu + + Returns: + bool: True if the new_pdu clobbered the current state, False if not + """ + return self._db_pool.runInteraction( + self._handle_new_state, new_pdu + ) + + def _handle_new_state(self, txn, new_pdu): + logger.debug( + "handle_new_state %s %s", + new_pdu.pdu_id, new_pdu.origin + ) + + current = self._get_current_interaction( + txn, + new_pdu.context, new_pdu.pdu_type, new_pdu.state_key + ) + + is_current = False + + if (not current or not current.prev_state_id + or not current.prev_state_origin): + # Oh, we don't have any state for this yet. + is_current = True + elif (current.pdu_id == new_pdu.prev_state_id + and current.origin == new_pdu.prev_state_origin): + # Oh! A direct clobber. Just do it. + is_current = True + else: + ## + # Ok, now loop through until we get to a common ancestor. + max_new = int(new_pdu.power_level) + max_current = int(current.power_level) + + enum_branches = self._enumerate_state_branches( + txn, new_pdu, current + ) + for branch, prev_state, state in enum_branches: + if not state: + raise RuntimeError( + "Could not find state_pdu %s %s" % + ( + prev_state.prev_state_id, + prev_state.prev_state_origin + ) + ) + + if branch == 0: + max_new = max(int(state.depth), max_new) + else: + max_current = max(int(state.depth), max_current) + + is_current = max_new > max_current + + if is_current: + logger.debug("handle_new_state make current") + + # Right, this is a new thing, so woo, just insert it. + txn.execute( + "INSERT OR REPLACE INTO %(curr)s (%(fields)s) VALUES (%(qs)s)" + % { + "curr": CurrentStateTable.table_name, + "fields": CurrentStateTable.get_fields_string(), + "qs": ", ".join(["?"] * len(CurrentStateTable.fields)) + }, + CurrentStateTable.EntryType( + *(new_pdu.__dict__[k] for k in CurrentStateTable.fields) + ) + ) + else: + logger.debug("handle_new_state not current") + + logger.debug("handle_new_state done") + + return is_current + + @classmethod + @log_function + def _enumerate_state_branches(cls, txn, pdu_a, pdu_b): + branch_a = pdu_a + branch_b = pdu_b + + get_query = ( + "SELECT %(fields)s FROM %(pdus)s as p " + "LEFT JOIN %(state)s as s " + "ON p.pdu_id = s.pdu_id AND p.origin = s.origin " + "WHERE p.pdu_id = ? AND p.origin = ? " + ) % { + "fields": _pdu_state_joiner.get_fields( + PdusTable="p", StatePdusTable="s"), + "pdus": PdusTable.table_name, + "state": StatePdusTable.table_name, + } + + while True: + if (branch_a.pdu_id == branch_b.pdu_id + and branch_a.origin == branch_b.origin): + # Woo! We found a common ancestor + logger.debug("_enumerate_state_branches Found common ancestor") + break + + do_branch_a = ( + hasattr(branch_a, "prev_state_id") and + branch_a.prev_state_id + ) + + do_branch_b = ( + hasattr(branch_b, "prev_state_id") and + branch_b.prev_state_id + ) + + logger.debug( + "do_branch_a=%s, do_branch_b=%s", + do_branch_a, do_branch_b + ) + + if do_branch_a and do_branch_b: + do_branch_a = int(branch_a.depth) > int(branch_b.depth) + + if do_branch_a: + pdu_tuple = PduIdTuple( + branch_a.prev_state_id, + branch_a.prev_state_origin + ) + + logger.debug("getting branch_a prev %s", pdu_tuple) + txn.execute(get_query, pdu_tuple) + + prev_branch = branch_a + + res = txn.fetchone() + branch_a = PduEntry(*res) if res else None + + logger.debug("branch_a=%s", branch_a) + + yield (0, prev_branch, branch_a) + + if not branch_a: + break + elif do_branch_b: + pdu_tuple = PduIdTuple( + branch_b.prev_state_id, + branch_b.prev_state_origin + ) + txn.execute(get_query, pdu_tuple) + + logger.debug("getting branch_b prev %s", pdu_tuple) + + prev_branch = branch_b + + res = txn.fetchone() + branch_b = PduEntry(*res) if res else None + + logger.debug("branch_b=%s", branch_b) + + yield (1, prev_branch, branch_b) + + if not branch_b: + break + else: + break + + +class PdusTable(Table): + table_name = "pdus" + + fields = [ + "pdu_id", + "origin", + "context", + "pdu_type", + "ts", + "depth", + "is_state", + "content_json", + "unrecognized_keys", + "outlier", + "have_processed", + ] + + EntryType = namedtuple("PdusEntry", fields) + + +class PduDestinationsTable(Table): + table_name = "pdu_destinations" + + fields = [ + "pdu_id", + "origin", + "destination", + "delivered_ts", + ] + + EntryType = namedtuple("PduDestinationsEntry", fields) + + +class PduEdgesTable(Table): + table_name = "pdu_edges" + + fields = [ + "pdu_id", + "origin", + "prev_pdu_id", + "prev_origin", + "context" + ] + + EntryType = namedtuple("PduEdgesEntry", fields) + + +class PduForwardExtremitiesTable(Table): + table_name = "pdu_forward_extremities" + + fields = [ + "pdu_id", + "origin", + "context", + ] + + EntryType = namedtuple("PduForwardExtremitiesEntry", fields) + + +class PduBackwardExtremitiesTable(Table): + table_name = "pdu_backward_extremities" + + fields = [ + "pdu_id", + "origin", + "context", + ] + + EntryType = namedtuple("PduBackwardExtremitiesEntry", fields) + + +class ContextDepthTable(Table): + table_name = "context_depth" + + fields = [ + "context", + "min_depth", + ] + + EntryType = namedtuple("ContextDepthEntry", fields) + + +class StatePdusTable(Table): + table_name = "state_pdus" + + fields = [ + "pdu_id", + "origin", + "context", + "pdu_type", + "state_key", + "power_level", + "prev_state_id", + "prev_state_origin", + ] + + EntryType = namedtuple("StatePdusEntry", fields) + + +class CurrentStateTable(Table): + table_name = "current_state" + + fields = [ + "pdu_id", + "origin", + "context", + "pdu_type", + "state_key", + ] + + EntryType = namedtuple("CurrentStateEntry", fields) + +_pdu_state_joiner = JoinHelper(PdusTable, StatePdusTable) + + +# TODO: These should probably be put somewhere more sensible +PduIdTuple = namedtuple("PduIdTuple", ("pdu_id", "origin")) + +PduEntry = _pdu_state_joiner.EntryType +""" We are always interested in the join of the PdusTable and StatePdusTable, +rather than just the PdusTable. + +This does not include a prev_pdus key. +""" + +PduTuple = namedtuple( + "PduTuple", + ("pdu_entry", "prev_pdu_list") +) +""" This is a tuple of a `PduEntry` and a list of `PduIdTuple` that represent +the `prev_pdus` key of a PDU. +""" diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py new file mode 100644 index 0000000000..e57ddaf149 --- /dev/null +++ b/synapse/storage/presence.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import SQLBaseStore + + +class PresenceStore(SQLBaseStore): + def create_presence(self, user_localpart): + return self._simple_insert( + table="presence", + values={"user_id": user_localpart}, + ) + + def has_presence_state(self, user_localpart): + return self._simple_select_one( + table="presence", + keyvalues={"user_id": user_localpart}, + retcols=["user_id"], + allow_none=True, + ) + + def get_presence_state(self, user_localpart): + return self._simple_select_one( + table="presence", + keyvalues={"user_id": user_localpart}, + retcols=["state", "status_msg"], + ) + + def set_presence_state(self, user_localpart, new_state): + return self._simple_update_one( + table="presence", + keyvalues={"user_id": user_localpart}, + updatevalues={"state": new_state["state"], + "status_msg": new_state["status_msg"]}, + retcols=["state"], + ) + + def allow_presence_visible(self, observed_localpart, observer_userid): + return self._simple_insert( + table="presence_allow_inbound", + values={"observed_user_id": observed_localpart, + "observer_user_id": observer_userid}, + ) + + def disallow_presence_visible(self, observed_localpart, observer_userid): + return self._simple_delete_one( + table="presence_allow_inbound", + keyvalues={"observed_user_id": observed_localpart, + "observer_user_id": observer_userid}, + ) + + def is_presence_visible(self, observed_localpart, observer_userid): + return self._simple_select_one( + table="presence_allow_inbound", + keyvalues={"observed_user_id": observed_localpart, + "observer_user_id": observer_userid}, + allow_none=True, + ) + + def add_presence_list_pending(self, observer_localpart, observed_userid): + return self._simple_insert( + table="presence_list", + values={"user_id": observer_localpart, + "observed_user_id": observed_userid, + "accepted": False}, + ) + + def set_presence_list_accepted(self, observer_localpart, observed_userid): + return self._simple_update_one( + table="presence_list", + keyvalues={"user_id": observer_localpart, + "observed_user_id": observed_userid}, + updatevalues={"accepted": True}, + ) + + def get_presence_list(self, observer_localpart, accepted=None): + keyvalues = {"user_id": observer_localpart} + if accepted is not None: + keyvalues["accepted"] = accepted + + return self._simple_select_list( + table="presence_list", + keyvalues=keyvalues, + retcols=["observed_user_id", "accepted"], + ) + + def del_presence_list(self, observer_localpart, observed_userid): + return self._simple_delete_one( + table="presence_list", + keyvalues={"user_id": observer_localpart, + "observed_user_id": observed_userid}, + ) diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py new file mode 100644 index 0000000000..d2f24930c1 --- /dev/null +++ b/synapse/storage/profile.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import SQLBaseStore + + +class ProfileStore(SQLBaseStore): + def create_profile(self, user_localpart): + return self._simple_insert( + table="profiles", + values={"user_id": user_localpart}, + ) + + def get_profile_displayname(self, user_localpart): + return self._simple_select_one_onecol( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcol="displayname", + ) + + def set_profile_displayname(self, user_localpart, new_displayname): + return self._simple_update_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + updatevalues={"displayname": new_displayname}, + ) + + def get_profile_avatar_url(self, user_localpart): + return self._simple_select_one_onecol( + table="profiles", + keyvalues={"user_id": user_localpart}, + retcol="avatar_url", + ) + + def set_profile_avatar_url(self, user_localpart, new_avatar_url): + return self._simple_update_one( + table="profiles", + keyvalues={"user_id": user_localpart}, + updatevalues={"avatar_url": new_avatar_url}, + ) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py new file mode 100644 index 0000000000..4a970dd546 --- /dev/null +++ b/synapse/storage/registration.py @@ -0,0 +1,113 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from sqlite3 import IntegrityError + +from synapse.api.errors import StoreError + +from ._base import SQLBaseStore + + +class RegistrationStore(SQLBaseStore): + + def __init__(self, hs): + super(RegistrationStore, self).__init__(hs) + + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def add_access_token_to_user(self, user_id, token): + """Adds an access token for the given user. + + Args: + user_id (str): The user ID. + token (str): The new access token to add. + Raises: + StoreError if there was a problem adding this. + """ + row = yield self._simple_select_one("users", {"name": user_id}, ["id"]) + if not row: + raise StoreError(400, "Bad user ID supplied.") + row_id = row["id"] + yield self._simple_insert( + "access_tokens", + { + "user_id": row_id, + "token": token + } + ) + + @defer.inlineCallbacks + def register(self, user_id, token, password_hash): + """Attempts to register an account. + + Args: + user_id (str): The desired user ID to register. + token (str): The desired access token to use for this user. + password_hash (str): Optional. The password hash for this user. + Raises: + StoreError if the user_id could not be registered. + """ + yield self._db_pool.runInteraction(self._register, user_id, token, + password_hash) + + def _register(self, txn, user_id, token, password_hash): + now = int(self.clock.time()) + + try: + txn.execute("INSERT INTO users(name, password_hash, creation_ts) " + "VALUES (?,?,?)", + [user_id, password_hash, now]) + except IntegrityError: + raise StoreError(400, "User ID already taken.") + + # it's possible for this to get a conflict, but only for a single user + # since tokens are namespaced based on their user ID + txn.execute("INSERT INTO access_tokens(user_id, token) " + + "VALUES (?,?)", [txn.lastrowid, token]) + + def get_user_by_id(self, user_id): + query = ("SELECT users.name, users.password_hash FROM users " + "WHERE users.name = ?") + return self._execute( + self.cursor_to_dict, + query, user_id + ) + + @defer.inlineCallbacks + def get_user_by_token(self, token): + """Get a user from the given access token. + + Args: + token (str): The access token of a user. + Returns: + str: The user ID of the user. + Raises: + StoreError if no user was found. + """ + user_id = yield self._db_pool.runInteraction(self._query_for_auth, + token) + defer.returnValue(user_id) + + def _query_for_auth(self, txn, token): + txn.execute("SELECT users.name FROM access_tokens LEFT JOIN users" + + " ON users.id = access_tokens.user_id WHERE token = ?", + [token]) + row = txn.fetchone() + if row: + return row[0] + + raise StoreError(404, "Token not found.") diff --git a/synapse/storage/room.py b/synapse/storage/room.py new file mode 100644 index 0000000000..174cbcf3d8 --- /dev/null +++ b/synapse/storage/room.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from sqlite3 import IntegrityError + +from synapse.api.errors import StoreError +from synapse.api.events.room import RoomTopicEvent + +from ._base import SQLBaseStore, Table + +import collections +import json +import logging + +logger = logging.getLogger(__name__) + + +class RoomStore(SQLBaseStore): + + @defer.inlineCallbacks + def store_room(self, room_id, room_creator_user_id, is_public): + """Stores a room. + + Args: + room_id (str): The desired room ID, can be None. + room_creator_user_id (str): The user ID of the room creator. + is_public (bool): True to indicate that this room should appear in + public room lists. + Raises: + StoreError if the room could not be stored. + """ + try: + yield self._simple_insert(RoomsTable.table_name, dict( + room_id=room_id, + creator=room_creator_user_id, + is_public=is_public + )) + except IntegrityError: + raise StoreError(409, "Room ID in use.") + except Exception as e: + logger.error("store_room with room_id=%s failed: %s", room_id, e) + raise StoreError(500, "Problem creating room.") + + def store_room_config(self, room_id, visibility): + return self._simple_update_one( + table=RoomsTable.table_name, + keyvalues={"room_id": room_id}, + updatevalues={"is_public": visibility} + ) + + def get_room(self, room_id): + """Retrieve a room. + + Args: + room_id (str): The ID of the room to retrieve. + Returns: + A namedtuple containing the room information, or an empty list. + """ + query = RoomsTable.select_statement("room_id=?") + return self._execute( + RoomsTable.decode_single_result, query, room_id, + ) + + @defer.inlineCallbacks + def get_rooms(self, is_public, with_topics): + """Retrieve a list of all public rooms. + + Args: + is_public (bool): True if the rooms returned should be public. + with_topics (bool): True to include the current topic for the room + in the response. + Returns: + A list of room dicts containing at least a "room_id" key, and a + "topic" key if one is set and with_topic=True. + """ + room_data_type = RoomTopicEvent.TYPE + public = 1 if is_public else 0 + + latest_topic = ("SELECT max(room_data.id) FROM room_data WHERE " + + "room_data.type = ? GROUP BY room_id") + + query = ("SELECT rooms.*, room_data.content FROM rooms LEFT JOIN " + + "room_data ON rooms.room_id = room_data.room_id WHERE " + + "(room_data.id IN (" + latest_topic + ") " + + "OR room_data.id IS NULL) AND rooms.is_public = ?") + + res = yield self._execute( + self.cursor_to_dict, query, room_data_type, public + ) + + # return only the keys the specification expects + ret_keys = ["room_id", "topic"] + + # extract topic from the json (icky) FIXME + for i, room_row in enumerate(res): + try: + content_json = json.loads(room_row["content"]) + room_row["topic"] = content_json["topic"] + except: + pass # no topic set + # filter the dict based on ret_keys + res[i] = {k: v for k, v in room_row.iteritems() if k in ret_keys} + + defer.returnValue(res) + + +class RoomsTable(Table): + table_name = "rooms" + + fields = [ + "room_id", + "is_public", + "creator" + ] + + EntryType = collections.namedtuple("RoomEntry", fields) diff --git a/synapse/storage/roomdata.py b/synapse/storage/roomdata.py new file mode 100644 index 0000000000..781d477931 --- /dev/null +++ b/synapse/storage/roomdata.py @@ -0,0 +1,84 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import SQLBaseStore, Table + +import collections +import json + + +class RoomDataStore(SQLBaseStore): + + """Provides various CRUD operations for Room Events. """ + + def get_room_data(self, room_id, etype, state_key=""): + """Retrieve the data stored under this type and state_key. + + Args: + room_id (str) + etype (str) + state_key (str) + Returns: + namedtuple: Or None if nothing exists at this path. + """ + query = RoomDataTable.select_statement( + "room_id = ? AND type = ? AND state_key = ? " + "ORDER BY id DESC LIMIT 1" + ) + return self._execute( + RoomDataTable.decode_single_result, + query, room_id, etype, state_key, + ) + + def store_room_data(self, room_id, etype, state_key="", content=None): + """Stores room specific data. + + Args: + room_id (str) + etype (str) + state_key (str) + data (str)- The data to store for this path in JSON. + Returns: + The store ID for this data. + """ + return self._simple_insert(RoomDataTable.table_name, dict( + etype=etype, + state_key=state_key, + room_id=room_id, + content=content, + )) + + def get_max_room_data_id(self): + return self._simple_max_id(RoomDataTable.table_name) + + +class RoomDataTable(Table): + table_name = "room_data" + + fields = [ + "id", + "room_id", + "type", + "state_key", + "content" + ] + + class EntryType(collections.namedtuple("RoomDataEntry", fields)): + + def as_event(self, event_factory): + return event_factory.create_event( + etype=self.type, + room_id=self.room_id, + content=json.loads(self.content), + ) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py new file mode 100644 index 0000000000..e6e7617797 --- /dev/null +++ b/synapse/storage/roommember.py @@ -0,0 +1,171 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from synapse.types import UserID +from synapse.api.constants import Membership +from synapse.api.events.room import RoomMemberEvent + +from ._base import SQLBaseStore, Table + + +import collections +import json +import logging + +logger = logging.getLogger(__name__) + + +class RoomMemberStore(SQLBaseStore): + + def get_room_member(self, user_id, room_id): + """Retrieve the current state of a room member. + + Args: + user_id (str): The member's user ID. + room_id (str): The room the member is in. + Returns: + namedtuple: The room member from the database, or None if this + member does not exist. + """ + query = RoomMemberTable.select_statement( + "room_id = ? AND user_id = ? ORDER BY id DESC LIMIT 1") + return self._execute( + RoomMemberTable.decode_single_result, + query, room_id, user_id, + ) + + def store_room_member(self, user_id, sender, room_id, membership, content): + """Store a room member in the database. + + Args: + user_id (str): The member's user ID. + room_id (str): The room in relation to the member. + membership (synapse.api.constants.Membership): The new membership + state. + content (dict): The content of the membership (JSON). + """ + content_json = json.dumps(content) + return self._simple_insert(RoomMemberTable.table_name, dict( + user_id=user_id, + sender=sender, + room_id=room_id, + membership=membership, + content=content_json, + )) + + @defer.inlineCallbacks + def get_room_members(self, room_id, membership=None): + """Retrieve the current room member list for a room. + + Args: + room_id (str): The room to get the list of members. + membership (synapse.api.constants.Membership): The filter to apply + to this list, or None to return all members with some state + associated with this room. + Returns: + list of namedtuples representing the members in this room. + """ + query = RoomMemberTable.select_statement( + "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name + + " WHERE room_id = ? GROUP BY user_id)" + ) + res = yield self._execute( + RoomMemberTable.decode_results, query, room_id, + ) + # strip memberships which don't match + if membership: + res = [entry for entry in res if entry.membership == membership] + defer.returnValue(res) + + def get_rooms_for_user_where_membership_is(self, user_id, membership_list): + """ Get all the rooms for this user where the membership for this user + matches one in the membership list. + + Args: + user_id (str): The user ID. + membership_list (list): A list of synapse.api.constants.Membership + values which the user must be in. + Returns: + A list of dicts with "room_id" and "membership" keys. + """ + if not membership_list: + return defer.succeed(None) + + args = [user_id] + membership_placeholder = ["membership=?"] * len(membership_list) + where_membership = "(" + " OR ".join(membership_placeholder) + ")" + for membership in membership_list: + args.append(membership) + + query = ("SELECT room_id, membership FROM room_memberships" + + " WHERE user_id=? AND " + where_membership + + " GROUP BY room_id ORDER BY id DESC") + return self._execute( + self.cursor_to_dict, query, *args + ) + + @defer.inlineCallbacks + def get_joined_hosts_for_room(self, room_id): + query = RoomMemberTable.select_statement( + "id IN (SELECT MAX(id) FROM " + RoomMemberTable.table_name + + " WHERE room_id = ? GROUP BY user_id)" + ) + + res = yield self._execute( + RoomMemberTable.decode_results, query, room_id, + ) + + def host_from_user_id_string(user_id): + domain = UserID.from_string(entry.user_id, self.hs).domain + return domain + + # strip memberships which don't match + hosts = [ + host_from_user_id_string(entry.user_id) + for entry in res + if entry.membership == Membership.JOIN + ] + + logger.debug("Returning hosts: %s from results: %s", hosts, res) + + defer.returnValue(hosts) + + def get_max_room_member_id(self): + return self._simple_max_id(RoomMemberTable.table_name) + + +class RoomMemberTable(Table): + table_name = "room_memberships" + + fields = [ + "id", + "user_id", + "sender", + "room_id", + "membership", + "content" + ] + + class EntryType(collections.namedtuple("RoomMemberEntry", fields)): + + def as_event(self, event_factory): + return event_factory.create_event( + etype=RoomMemberEvent.TYPE, + room_id=self.room_id, + target_user_id=self.user_id, + user_id=self.sender, + content=json.loads(self.content), + ) diff --git a/synapse/storage/schema/edge_pdus.sql b/synapse/storage/schema/edge_pdus.sql new file mode 100644 index 0000000000..17b3c52f0d --- /dev/null +++ b/synapse/storage/schema/edge_pdus.sql @@ -0,0 +1,31 @@ +/* Copyright 2014 matrix.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS context_edge_pdus( + id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this + pdu_id TEXT, + origin TEXT, + context TEXT, + CONSTRAINT context_edge_pdu_id_origin UNIQUE (pdu_id, origin) +); + +CREATE TABLE IF NOT EXISTS origin_edge_pdus( + id INTEGER PRIMARY KEY AUTOINCREMENT, -- twistar requires this + pdu_id TEXT, + origin TEXT, + CONSTRAINT origin_edge_pdu_id_origin UNIQUE (pdu_id, origin) +); + +CREATE INDEX IF NOT EXISTS context_edge_pdu_id ON context_edge_pdus(pdu_id, origin); +CREATE INDEX IF NOT EXISTS origin_edge_pdu_id ON origin_edge_pdus(pdu_id, origin); diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql new file mode 100644 index 0000000000..77096546b2 --- /dev/null +++ b/synapse/storage/schema/im.sql @@ -0,0 +1,54 @@ +/* Copyright 2014 matrix.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS rooms( + room_id TEXT PRIMARY KEY NOT NULL, + is_public INTEGER, + creator TEXT +); + +CREATE TABLE IF NOT EXISTS room_memberships( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT NOT NULL, -- no foreign key to users table, it could be an id belonging to another home server + sender TEXT NOT NULL, + room_id TEXT NOT NULL, + membership TEXT NOT NULL, + content TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS messages( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id TEXT, + room_id TEXT, + msg_id TEXT, + content TEXT +); + +CREATE TABLE IF NOT EXISTS feedback( + id INTEGER PRIMARY KEY AUTOINCREMENT, + content TEXT, + feedback_type TEXT, + fb_sender_id TEXT, + msg_id TEXT, + room_id TEXT, + msg_sender_id TEXT +); + +CREATE TABLE IF NOT EXISTS room_data( + id INTEGER PRIMARY KEY AUTOINCREMENT, + room_id TEXT NOT NULL, + type TEXT NOT NULL, + state_key TEXT NOT NULL, + content TEXT +); diff --git a/synapse/storage/schema/pdu.sql b/synapse/storage/schema/pdu.sql new file mode 100644 index 0000000000..ca3de005e9 --- /dev/null +++ b/synapse/storage/schema/pdu.sql @@ -0,0 +1,106 @@ +/* Copyright 2014 matrix.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- Stores pdus and their content +CREATE TABLE IF NOT EXISTS pdus( + pdu_id TEXT, + origin TEXT, + context TEXT, + pdu_type TEXT, + ts INTEGER, + depth INTEGER DEFAULT 0 NOT NULL, + is_state BOOL, + content_json TEXT, + unrecognized_keys TEXT, + outlier BOOL NOT NULL, + have_processed BOOL, + CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) +); + +-- Stores what the current state pdu is for a given (context, pdu_type, key) tuple +CREATE TABLE IF NOT EXISTS state_pdus( + pdu_id TEXT, + origin TEXT, + context TEXT, + pdu_type TEXT, + state_key TEXT, + power_level TEXT, + prev_state_id TEXT, + prev_state_origin TEXT, + CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) + CONSTRAINT prev_pdu_id_origin UNIQUE (prev_state_id, prev_state_origin) +); + +CREATE TABLE IF NOT EXISTS current_state( + pdu_id TEXT, + origin TEXT, + context TEXT, + pdu_type TEXT, + state_key TEXT, + CONSTRAINT pdu_id_origin UNIQUE (pdu_id, origin) + CONSTRAINT uniqueness UNIQUE (context, pdu_type, state_key) ON CONFLICT REPLACE +); + +-- Stores where each pdu we want to send should be sent and the delivery status. +create TABLE IF NOT EXISTS pdu_destinations( + pdu_id TEXT, + origin TEXT, + destination TEXT, + delivered_ts INTEGER DEFAULT 0, -- or 0 if not delivered + CONSTRAINT uniqueness UNIQUE (pdu_id, origin, destination) ON CONFLICT REPLACE +); + +CREATE TABLE IF NOT EXISTS pdu_forward_extremities( + pdu_id TEXT, + origin TEXT, + context TEXT, + CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE +); + +CREATE TABLE IF NOT EXISTS pdu_backward_extremities( + pdu_id TEXT, + origin TEXT, + context TEXT, + CONSTRAINT uniqueness UNIQUE (pdu_id, origin, context) ON CONFLICT REPLACE +); + +CREATE TABLE IF NOT EXISTS pdu_edges( + pdu_id TEXT, + origin TEXT, + prev_pdu_id TEXT, + prev_origin TEXT, + context TEXT, + CONSTRAINT uniqueness UNIQUE (pdu_id, origin, prev_pdu_id, prev_origin, context) +); + +CREATE TABLE IF NOT EXISTS context_depth( + context TEXT, + min_depth INTEGER, + CONSTRAINT uniqueness UNIQUE (context) +); + +CREATE INDEX IF NOT EXISTS context_depth_context ON context_depth(context); + + +CREATE INDEX IF NOT EXISTS pdu_id ON pdus(pdu_id, origin); + +CREATE INDEX IF NOT EXISTS dests_id ON pdu_destinations (pdu_id, origin); +-- CREATE INDEX IF NOT EXISTS dests ON pdu_destinations (destination); + +CREATE INDEX IF NOT EXISTS pdu_extrem_context ON pdu_forward_extremities(context); +CREATE INDEX IF NOT EXISTS pdu_extrem_id ON pdu_forward_extremities(pdu_id, origin); + +CREATE INDEX IF NOT EXISTS pdu_edges_id ON pdu_edges(pdu_id, origin); + +CREATE INDEX IF NOT EXISTS pdu_b_extrem_context ON pdu_backward_extremities(context); diff --git a/synapse/storage/schema/presence.sql b/synapse/storage/schema/presence.sql new file mode 100644 index 0000000000..b22e3ba863 --- /dev/null +++ b/synapse/storage/schema/presence.sql @@ -0,0 +1,37 @@ +/* Copyright 2014 matrix.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS presence( + user_id INTEGER NOT NULL, + state INTEGER, + status_msg TEXT, + FOREIGN KEY(user_id) REFERENCES users(id) +); + +-- For each of /my/ users which possibly-remote users are allowed to see their +-- presence state +CREATE TABLE IF NOT EXISTS presence_allow_inbound( + observed_user_id INTEGER NOT NULL, + observer_user_id TEXT, -- a UserID, + FOREIGN KEY(observed_user_id) REFERENCES users(id) +); + +-- For each of /my/ users (watcher), which possibly-remote users are they +-- watching? +CREATE TABLE IF NOT EXISTS presence_list( + user_id INTEGER NOT NULL, + observed_user_id TEXT, -- a UserID, + accepted BOOLEAN, + FOREIGN KEY(user_id) REFERENCES users(id) +); diff --git a/synapse/storage/schema/profiles.sql b/synapse/storage/schema/profiles.sql new file mode 100644 index 0000000000..1092d7672c --- /dev/null +++ b/synapse/storage/schema/profiles.sql @@ -0,0 +1,20 @@ +/* Copyright 2014 matrix.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS profiles( + user_id INTEGER NOT NULL, + displayname TEXT, + avatar_url TEXT, + FOREIGN KEY(user_id) REFERENCES users(id) +); diff --git a/synapse/storage/schema/room_aliases.sql b/synapse/storage/schema/room_aliases.sql new file mode 100644 index 0000000000..71a8b90e4d --- /dev/null +++ b/synapse/storage/schema/room_aliases.sql @@ -0,0 +1,12 @@ +CREATE TABLE IF NOT EXISTS room_aliases( + room_alias TEXT NOT NULL, + room_id TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS room_alias_servers( + room_alias TEXT NOT NULL, + server TEXT NOT NULL +); + + + diff --git a/synapse/storage/schema/transactions.sql b/synapse/storage/schema/transactions.sql new file mode 100644 index 0000000000..4b1a2368f6 --- /dev/null +++ b/synapse/storage/schema/transactions.sql @@ -0,0 +1,61 @@ +/* Copyright 2014 matrix.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- Stores what transaction ids we have received and what our response was +CREATE TABLE IF NOT EXISTS received_transactions( + transaction_id TEXT, + origin TEXT, + ts INTEGER, + response_code INTEGER, + response_json TEXT, + has_been_referenced BOOL default 0, -- Whether thishas been referenced by a prev_tx + CONSTRAINT uniquesss UNIQUE (transaction_id, origin) ON CONFLICT REPLACE +); + +CREATE UNIQUE INDEX IF NOT EXISTS transactions_txid ON received_transactions(transaction_id, origin); +CREATE INDEX IF NOT EXISTS transactions_have_ref ON received_transactions(origin, has_been_referenced);-- WHERE has_been_referenced = 0; + + +-- Stores what transactions we've sent, what their response was (if we got one) and whether we have +-- since referenced the transaction in another outgoing transaction +CREATE TABLE IF NOT EXISTS sent_transactions( + id INTEGER PRIMARY KEY AUTOINCREMENT, -- This is used to apply insertion ordering + transaction_id TEXT, + destination TEXT, + response_code INTEGER DEFAULT 0, + response_json TEXT, + ts INTEGER +); + +CREATE INDEX IF NOT EXISTS sent_transaction_dest ON sent_transactions(destination); +CREATE INDEX IF NOT EXISTS sent_transaction_dest_referenced ON sent_transactions( + destination +); +-- So that we can do an efficient look up of all transactions that have yet to be successfully +-- sent. +CREATE INDEX IF NOT EXISTS sent_transaction_sent ON sent_transactions(response_code); + + +-- For sent transactions only. +CREATE TABLE IF NOT EXISTS transaction_id_to_pdu( + transaction_id INTEGER, + destination TEXT, + pdu_id TEXT, + pdu_origin TEXT +); + +CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_tx ON transaction_id_to_pdu(transaction_id, destination); +CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_dest ON transaction_id_to_pdu(destination); +CREATE INDEX IF NOT EXISTS transaction_id_to_pdu_index ON transaction_id_to_pdu(transaction_id, destination); + diff --git a/synapse/storage/schema/users.sql b/synapse/storage/schema/users.sql new file mode 100644 index 0000000000..46b60297cb --- /dev/null +++ b/synapse/storage/schema/users.sql @@ -0,0 +1,31 @@ +/* Copyright 2014 matrix.org + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS users( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT, + password_hash TEXT, + creation_ts INTEGER, + UNIQUE(name) ON CONFLICT ROLLBACK +); + +CREATE TABLE IF NOT EXISTS access_tokens( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_id INTEGER NOT NULL, + device_id TEXT, + token TEXT NOT NULL, + last_used INTEGER, + FOREIGN KEY(user_id) REFERENCES users(id), + UNIQUE(token) ON CONFLICT ROLLBACK +); diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py new file mode 100644 index 0000000000..c3b1bfeb32 --- /dev/null +++ b/synapse/storage/stream.py @@ -0,0 +1,282 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore +from .message import MessagesTable +from .feedback import FeedbackTable +from .roomdata import RoomDataTable +from .roommember import RoomMemberTable + +import json +import logging + +logger = logging.getLogger(__name__) + + +class StreamStore(SQLBaseStore): + + def get_message_stream(self, user_id, from_key, to_key, room_id, limit=0, + with_feedback=False): + """Get all messages for this user between the given keys. + + Args: + user_id (str): The user who is requesting messages. + from_key (int): The ID to start returning results from (exclusive). + to_key (int): The ID to stop returning results (exclusive). + room_id (str): Gets messages only for this room. Can be None, in + which case all room messages will be returned. + Returns: + A tuple of rows (list of namedtuples), new_id(int) + """ + if with_feedback and room_id: # with fb MUST specify a room ID + return self._db_pool.runInteraction( + self._get_message_rows_with_feedback, + user_id, from_key, to_key, room_id, limit + ) + else: + return self._db_pool.runInteraction( + self._get_message_rows, + user_id, from_key, to_key, room_id, limit + ) + + def _get_message_rows(self, txn, user_id, from_pkey, to_pkey, room_id, + limit): + # work out which rooms this user is joined in on and join them with + # the room id on the messages table, bounded by the specified pkeys + + # get all messages where the *current* membership state is 'join' for + # this user in that room. + query = ("SELECT messages.* FROM messages WHERE ? IN" + + " (SELECT membership from room_memberships WHERE user_id=?" + + " AND room_id = messages.room_id ORDER BY id DESC LIMIT 1)") + query_args = ["join", user_id] + + if room_id: + query += " AND messages.room_id=?" + query_args.append(room_id) + + (query, query_args) = self._append_stream_operations( + "messages", query, query_args, from_pkey, to_pkey, limit=limit + ) + + logger.debug("[SQL] %s : %s", query, query_args) + cursor = txn.execute(query, query_args) + return self._as_events(cursor, MessagesTable, from_pkey) + + def _get_message_rows_with_feedback(self, txn, user_id, from_pkey, to_pkey, + room_id, limit): + # this col represents the compressed feedback JSON as per spec + compressed_feedback_col = ( + "'[' || group_concat('{\"sender_id\":\"' || f.fb_sender_id" + + " || '\",\"feedback_type\":\"' || f.feedback_type" + + " || '\",\"content\":' || f.content || '}') || ']'" + ) + + global_msg_id_join = ("f.room_id = messages.room_id" + + " and f.msg_id = messages.msg_id" + + " and messages.user_id = f.msg_sender_id") + + select_query = ( + "SELECT messages.*, f.content AS fb_content, f.fb_sender_id" + + ", " + compressed_feedback_col + " AS compressed_fb" + + " FROM messages LEFT JOIN feedback f ON " + global_msg_id_join) + + current_membership_sub_query = ( + "(SELECT membership from room_memberships rm" + + " WHERE user_id=? AND room_id = rm.room_id" + + " ORDER BY id DESC LIMIT 1)") + + where = (" WHERE ? IN " + current_membership_sub_query + + " AND messages.room_id=?") + + query = select_query + where + query_args = ["join", user_id, room_id] + + (query, query_args) = self._append_stream_operations( + "messages", query, query_args, from_pkey, to_pkey, + limit=limit, group_by=" GROUP BY messages.id " + ) + + logger.debug("[SQL] %s : %s", query, query_args) + cursor = txn.execute(query, query_args) + + # convert the result set into events + entries = self.cursor_to_dict(cursor) + events = [] + for entry in entries: + # TODO we should spec the cursor > event mapping somewhere else. + event = {} + straight_mappings = ["msg_id", "user_id", "room_id"] + for key in straight_mappings: + event[key] = entry[key] + event["content"] = json.loads(entry["content"]) + if entry["compressed_fb"]: + event["feedback"] = json.loads(entry["compressed_fb"]) + events.append(event) + + latest_pkey = from_pkey if len(entries) == 0 else entries[-1]["id"] + + return (events, latest_pkey) + + def get_room_member_stream(self, user_id, from_key, to_key): + """Get all room membership events for this user between the given keys. + + Args: + user_id (str): The user who is requesting membership events. + from_key (int): The ID to start returning results from (exclusive). + to_key (int): The ID to stop returning results (exclusive). + Returns: + A tuple of rows (list of namedtuples), new_id(int) + """ + return self._db_pool.runInteraction( + self._get_room_member_rows, user_id, from_key, to_key + ) + + def _get_room_member_rows(self, txn, user_id, from_pkey, to_pkey): + # get all room membership events for rooms which the user is + # *currently* joined in on, or all invite events for this user. + current_membership_sub_query = ( + "(SELECT membership FROM room_memberships" + + " WHERE user_id=? AND room_id = rm.room_id" + + " ORDER BY id DESC LIMIT 1)") + + query = ("SELECT rm.* FROM room_memberships rm " + # all membership events for rooms you've currently joined. + + " WHERE (? IN " + current_membership_sub_query + # all invite membership events for this user + + " OR rm.membership=? AND user_id=?)" + + " AND rm.id > ?") + query_args = ["join", user_id, "invite", user_id, from_pkey] + + if to_pkey != -1: + query += " AND rm.id < ?" + query_args.append(to_pkey) + + cursor = txn.execute(query, query_args) + return self._as_events(cursor, RoomMemberTable, from_pkey) + + def get_feedback_stream(self, user_id, from_key, to_key, room_id, limit=0): + return self._db_pool.runInteraction( + self._get_feedback_rows, + user_id, from_key, to_key, room_id, limit + ) + + def _get_feedback_rows(self, txn, user_id, from_pkey, to_pkey, room_id, + limit): + # work out which rooms this user is joined in on and join them with + # the room id on the feedback table, bounded by the specified pkeys + + # get all messages where the *current* membership state is 'join' for + # this user in that room. + query = ( + "SELECT feedback.* FROM feedback WHERE ? IN " + + "(SELECT membership from room_memberships WHERE user_id=?" + + " AND room_id = feedback.room_id ORDER BY id DESC LIMIT 1)") + query_args = ["join", user_id] + + if room_id: + query += " AND feedback.room_id=?" + query_args.append(room_id) + + (query, query_args) = self._append_stream_operations( + "feedback", query, query_args, from_pkey, to_pkey, limit=limit + ) + + logger.debug("[SQL] %s : %s", query, query_args) + cursor = txn.execute(query, query_args) + return self._as_events(cursor, FeedbackTable, from_pkey) + + def get_room_data_stream(self, user_id, from_key, to_key, room_id, + limit=0): + return self._db_pool.runInteraction( + self._get_room_data_rows, + user_id, from_key, to_key, room_id, limit + ) + + def _get_room_data_rows(self, txn, user_id, from_pkey, to_pkey, room_id, + limit): + # work out which rooms this user is joined in on and join them with + # the room id on the feedback table, bounded by the specified pkeys + + # get all messages where the *current* membership state is 'join' for + # this user in that room. + query = ( + "SELECT room_data.* FROM room_data WHERE ? IN " + + "(SELECT membership from room_memberships WHERE user_id=?" + + " AND room_id = room_data.room_id ORDER BY id DESC LIMIT 1)") + query_args = ["join", user_id] + + if room_id: + query += " AND room_data.room_id=?" + query_args.append(room_id) + + (query, query_args) = self._append_stream_operations( + "room_data", query, query_args, from_pkey, to_pkey, limit=limit + ) + + logger.debug("[SQL] %s : %s", query, query_args) + cursor = txn.execute(query, query_args) + return self._as_events(cursor, RoomDataTable, from_pkey) + + def _append_stream_operations(self, table_name, query, query_args, + from_pkey, to_pkey, limit=None, + group_by=""): + LATEST_ROW = -1 + order_by = "" + if to_pkey > from_pkey: + if from_pkey != LATEST_ROW: + # e.g. from=5 to=9 >> from 5 to 9 >> id>5 AND id<9 + query += (" AND %s.id > ? AND %s.id < ?" % + (table_name, table_name)) + query_args.append(from_pkey) + query_args.append(to_pkey) + else: + # e.g. from=-1 to=5 >> from now to 5 >> id>5 ORDER BY id DESC + query += " AND %s.id > ? " % table_name + order_by = "ORDER BY id DESC" + query_args.append(to_pkey) + elif from_pkey > to_pkey: + if to_pkey != LATEST_ROW: + # from=9 to=5 >> from 9 to 5 >> id>5 AND id<9 ORDER BY id DESC + query += (" AND %s.id > ? AND %s.id < ? " % + (table_name, table_name)) + order_by = "ORDER BY id DESC" + query_args.append(to_pkey) + query_args.append(from_pkey) + else: + # from=5 to=-1 >> from 5 to now >> id>5 + query += " AND %s.id > ?" % table_name + query_args.append(from_pkey) + + query += group_by + order_by + + if limit and limit > 0: + query += " LIMIT ?" + query_args.append(str(limit)) + + return (query, query_args) + + def _as_events(self, cursor, table, from_pkey): + data_entries = table.decode_results(cursor) + last_pkey = from_pkey + if data_entries: + last_pkey = data_entries[-1].id + + events = [ + entry.as_event(self.event_factory).get_dict() + for entry in data_entries + ] + + return (events, last_pkey) diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py new file mode 100644 index 0000000000..aa41e2ad7f --- /dev/null +++ b/synapse/storage/transactions.py @@ -0,0 +1,287 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 matrix.org +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ._base import SQLBaseStore, Table +from .pdu import PdusTable + +from collections import namedtuple + +import logging + +logger = logging.getLogger(__name__) + + +class TransactionStore(SQLBaseStore): + """A collection of queries for handling PDUs. + """ + + def get_received_txn_response(self, transaction_id, origin): + """For an incoming transaction from a given origin, check if we have + already responded to it. If so, return the response code and response + body (as a dict). + + Args: + transaction_id (str) + origin(str) + + Returns: + tuple: None if we have not previously responded to + this transaction or a 2-tuple of (int, dict) + """ + + return self._db_pool.runInteraction( + self._get_received_txn_response, transaction_id, origin + ) + + def _get_received_txn_response(self, txn, transaction_id, origin): + where_clause = "transaction_id = ? AND origin = ?" + query = ReceivedTransactionsTable.select_statement(where_clause) + + txn.execute(query, (transaction_id, origin)) + + results = ReceivedTransactionsTable.decode_results(txn.fetchall()) + + if results and results[0].response_code: + return (results[0].response_code, results[0].response_json) + else: + return None + + def set_received_txn_response(self, transaction_id, origin, code, + response_dict): + """Persist the response we returened for an incoming transaction, and + should return for subsequent transactions with the same transaction_id + and origin. + + Args: + txn + transaction_id (str) + origin (str) + code (int) + response_json (str) + """ + + return self._db_pool.runInteraction( + self._set_received_txn_response, + transaction_id, origin, code, response_dict + ) + + def _set_received_txn_response(self, txn, transaction_id, origin, code, + response_json): + query = ( + "UPDATE %s " + "SET response_code = ?, response_json = ? " + "WHERE transaction_id = ? AND origin = ?" + ) % ReceivedTransactionsTable.table_name + + txn.execute(query, (code, response_json, transaction_id, origin)) + + def prep_send_transaction(self, transaction_id, destination, ts, pdu_list): + """Persists an outgoing transaction and calculates the values for the + previous transaction id list. + + This should be called before sending the transaction so that it has the + correct value for the `prev_ids` key. + + Args: + transaction_id (str) + destination (str) + ts (int) + pdu_list (list) + + Returns: + list: A list of previous transaction ids. + """ + + return self._db_pool.runInteraction( + self._prep_send_transaction, + transaction_id, destination, ts, pdu_list + ) + + def _prep_send_transaction(self, txn, transaction_id, destination, ts, + pdu_list): + + # First we find out what the prev_txs should be. + # Since we know that we are only sending one transaction at a time, + # we can simply take the last one. + query = "%s ORDER BY id DESC LIMIT 1" % ( + SentTransactions.select_statement("destination = ?"), + ) + + results = txn.execute(query, (destination,)) + results = SentTransactions.decode_results(results) + + prev_txns = [r.transaction_id for r in results] + + # Actually add the new transaction to the sent_transactions table. + + query = SentTransactions.insert_statement() + txn.execute(query, SentTransactions.EntryType( + None, + transaction_id=transaction_id, + destination=destination, + ts=ts, + response_code=0, + response_json=None + )) + + # Update the tx id -> pdu id mapping + + values = [ + (transaction_id, destination, pdu[0], pdu[1]) + for pdu in pdu_list + ] + + logger.debug("Inserting: %s", repr(values)) + + query = TransactionsToPduTable.insert_statement() + txn.executemany(query, values) + + return prev_txns + + def delivered_txn(self, transaction_id, destination, code, response_dict): + """Persists the response for an outgoing transaction. + + Args: + transaction_id (str) + destination (str) + code (int) + response_json (str) + """ + return self._db_pool.runInteraction( + self._delivered_txn, + transaction_id, destination, code, response_dict + ) + + def _delivered_txn(cls, txn, transaction_id, destination, + code, response_json): + query = ( + "UPDATE %s " + "SET response_code = ?, response_json = ? " + "WHERE transaction_id = ? AND destination = ?" + ) % SentTransactions.table_name + + txn.execute(query, (code, response_json, transaction_id, destination)) + + def get_transactions_after(self, transaction_id, destination): + """Get all transactions after a given local transaction_id. + + Args: + transaction_id (str) + destination (str) + + Returns: + list: A list of `ReceivedTransactionsTable.EntryType` + """ + return self._db_pool.runInteraction( + self._get_transactions_after, transaction_id, destination + ) + + def _get_transactions_after(cls, txn, transaction_id, destination): + where = ( + "destination = ? AND id > (select id FROM %s WHERE " + "transaction_id = ? AND destination = ?)" + ) % ( + SentTransactions.table_name + ) + query = SentTransactions.select_statement(where) + + txn.execute(query, (destination, transaction_id, destination)) + + return ReceivedTransactionsTable.decode_results(txn.fetchall()) + + def get_pdus_after_transaction(self, transaction_id, destination): + """For a given local transaction_id that we sent to a given destination + home server, return a list of PDUs that were sent to that destination + after it. + + Args: + txn + transaction_id (str) + destination (str) + + Returns + list: A list of PduTuple + """ + return self._db_pool.runInteraction( + self._get_pdus_after_transaction, + transaction_id, destination + ) + + def _get_pdus_after_transaction(self, txn, transaction_id, destination): + + # Query that first get's all transaction_ids with an id greater than + # the one given from the `sent_transactions` table. Then JOIN on this + # from the `tx->pdu` table to get a list of (pdu_id, origin) that + # specify the pdus that were sent in those transactions. + query = ( + "SELECT pdu_id, pdu_origin FROM %(tx_pdu)s as tp " + "INNER JOIN %(sent_tx)s as st " + "ON tp.transaction_id = st.transaction_id " + "AND tp.destination = st.destination " + "WHERE st.id > (" + "SELECT id FROM %(sent_tx)s " + "WHERE transaction_id = ? AND destination = ?" + ) % { + "tx_pdu": TransactionsToPduTable.table_name, + "sent_tx": SentTransactions.table_name, + } + + txn.execute(query, (transaction_id, destination)) + + pdus = PdusTable.decode_results(txn.fetchall()) + + return self._get_pdu_tuples(txn, pdus) + + +class ReceivedTransactionsTable(Table): + table_name = "received_transactions" + + fields = [ + "transaction_id", + "origin", + "ts", + "response_code", + "response_json", + "has_been_referenced", + ] + + EntryType = namedtuple("ReceivedTransactionsEntry", fields) + + +class SentTransactions(Table): + table_name = "sent_transactions" + + fields = [ + "id", + "transaction_id", + "destination", + "ts", + "response_code", + "response_json", + ] + + EntryType = namedtuple("SentTransactionsEntry", fields) + + +class TransactionsToPduTable(Table): + table_name = "transaction_id_to_pdu" + + fields = [ + "transaction_id", + "destination", + "pdu_id", + "pdu_origin", + ] + + EntryType = namedtuple("TransactionsToPduEntry", fields) |