summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py6
-rw-r--r--synapse/storage/_base.py22
-rw-r--r--synapse/storage/appservice.py114
-rw-r--r--synapse/storage/event_federation.py51
-rw-r--r--synapse/storage/roommember.py36
-rw-r--r--synapse/storage/stream.py83
6 files changed, 279 insertions, 33 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index d16e7b8fac..d6ec446bd2 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -74,7 +74,7 @@ SCHEMAS = [
 
 # Remember to update this number every time an incompatible change is made to
 # database schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 13
+SCHEMA_VERSION = 14
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -633,11 +633,11 @@ def prepare_database(db_conn):
 
             # Run every version since after the current version.
             for v in range(user_version + 1, SCHEMA_VERSION + 1):
-                if v == 10:
+                if v in (10, 14,):
                     raise UpgradeDatabaseException(
                         "No delta for version 10"
                     )
-                sql_script = read_schema("delta/v%d" % (v))
+                sql_script = read_schema("delta/v%d" % (v,))
                 c.executescript(sql_script)
 
             db_conn.commit()
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index c98dd36aed..3725c9795d 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -450,7 +450,8 @@ class SQLBaseStore(object):
 
         Args:
             table : string giving the table name
-            keyvalues : dict of column names and values to select the rows with
+            keyvalues : dict of column names and values to select the rows with,
+            or None to not apply a WHERE clause.
             retcols : list of strings giving the names of the columns to return
         """
         return self.runInteraction(
@@ -469,13 +470,20 @@ class SQLBaseStore(object):
             keyvalues : dict of column names and values to select the rows with
             retcols : list of strings giving the names of the columns to return
         """
-        sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
-            ", ".join(retcols),
-            table,
-            " AND ".join("%s = ?" % (k, ) for k in keyvalues)
-        )
+        if keyvalues:
+            sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
+                ", ".join(retcols),
+                table,
+                " AND ".join("%s = ?" % (k, ) for k in keyvalues)
+            )
+            txn.execute(sql, keyvalues.values())
+        else:
+            sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
+                ", ".join(retcols),
+                table
+            )
+            txn.execute(sql)
 
-        txn.execute(sql, keyvalues.values())
         return self.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index dc3666efd4..e30265750a 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -13,22 +13,32 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import simplejson
+from simplejson import JSONDecodeError
 from twisted.internet import defer
 
+from synapse.api.constants import Membership
 from synapse.api.errors import StoreError
 from synapse.appservice import ApplicationService
+from synapse.storage.roommember import RoomsForUser
 from ._base import SQLBaseStore
 
 
 logger = logging.getLogger(__name__)
 
 
+def log_failure(failure):
+    logger.error("Failed to detect application services: %s", failure.value)
+    logger.error(failure.getTraceback())
+
+
 class ApplicationServiceStore(SQLBaseStore):
 
     def __init__(self, hs):
         super(ApplicationServiceStore, self).__init__(hs)
         self.services_cache = []
         self.cache_defer = self._populate_cache()
+        self.cache_defer.addErrback(log_failure)
 
     @defer.inlineCallbacks
     def unregister_app_service(self, token):
@@ -128,11 +138,11 @@ class ApplicationServiceStore(SQLBaseStore):
         )
         for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST):
             if ns_str in service.namespaces:
-                for regex in service.namespaces[ns_str]:
+                for regex_obj in service.namespaces[ns_str]:
                     txn.execute(
                         "INSERT INTO application_services_regex("
                         "as_id, namespace, regex) values(?,?,?)",
-                        (as_id, ns_int, regex)
+                        (as_id, ns_int, simplejson.dumps(regex_obj))
                     )
         return True
 
@@ -151,8 +161,31 @@ class ApplicationServiceStore(SQLBaseStore):
         defer.returnValue(self.services_cache)
 
     @defer.inlineCallbacks
+    def get_app_service_by_user_id(self, user_id):
+        """Retrieve an application service from their user ID.
+
+        All application services have associated with them a particular user ID.
+        There is no distinguishing feature on the user ID which indicates it
+        represents an application service. This function allows you to map from
+        a user ID to an application service.
+
+        Args:
+            user_id(str): The user ID to see if it is an application service.
+        Returns:
+            synapse.appservice.ApplicationService or None.
+        """
+
+        yield self.cache_defer  # make sure the cache is ready
+
+        for service in self.services_cache:
+            if service.sender == user_id:
+                defer.returnValue(service)
+                return
+        defer.returnValue(None)
+
+    @defer.inlineCallbacks
     def get_app_service_by_token(self, token, from_cache=True):
-        """Get the application service with the given token.
+        """Get the application service with the given appservice token.
 
         Args:
             token (str): The application service token.
@@ -173,6 +206,77 @@ class ApplicationServiceStore(SQLBaseStore):
         # TODO: The from_cache=False impl
         # TODO: This should be JOINed with the application_services_regex table.
 
+    def get_app_service_rooms(self, service):
+        """Get a list of RoomsForUser for this application service.
+
+        Application services may be "interested" in lots of rooms depending on
+        the room ID, the room aliases, or the members in the room. This function
+        takes all of these into account and returns a list of RoomsForUser which
+        represent the entire list of room IDs that this application service
+        wants to know about.
+
+        Args:
+            service: The application service to get a room list for.
+        Returns:
+            A list of RoomsForUser.
+        """
+        return self.runInteraction(
+            "get_app_service_rooms",
+            self._get_app_service_rooms_txn,
+            service,
+        )
+
+    def _get_app_service_rooms_txn(self, txn, service):
+        # get all rooms matching the room ID regex.
+        room_entries = self._simple_select_list_txn(
+            txn=txn, table="rooms", keyvalues=None, retcols=["room_id"]
+        )
+        matching_room_list = set([
+            r["room_id"] for r in room_entries if
+            service.is_interested_in_room(r["room_id"])
+        ])
+
+        # resolve room IDs for matching room alias regex.
+        room_alias_mappings = self._simple_select_list_txn(
+            txn=txn, table="room_aliases", keyvalues=None,
+            retcols=["room_id", "room_alias"]
+        )
+        matching_room_list |= set([
+            r["room_id"] for r in room_alias_mappings if
+            service.is_interested_in_alias(r["room_alias"])
+        ])
+
+        # get all rooms for every user for this AS. This is scoped to users on
+        # this HS only.
+        user_list = self._simple_select_list_txn(
+            txn=txn, table="users", keyvalues=None, retcols=["name"]
+        )
+        user_list = [
+            u["name"] for u in user_list if
+            service.is_interested_in_user(u["name"])
+        ]
+        rooms_for_user_matching_user_id = set()  # RoomsForUser list
+        for user_id in user_list:
+            # FIXME: This assumes this store is linked with RoomMemberStore :(
+            rooms_for_user = self._get_rooms_for_user_where_membership_is_txn(
+                txn=txn,
+                user_id=user_id,
+                membership_list=[Membership.JOIN]
+            )
+            rooms_for_user_matching_user_id |= set(rooms_for_user)
+
+        # make RoomsForUser tuples for room ids and aliases which are not in the
+        # main rooms_for_user_list - e.g. they are rooms which do not have AS
+        # registered users in it.
+        known_room_ids = [r.room_id for r in rooms_for_user_matching_user_id]
+        missing_rooms_for_user = [
+            RoomsForUser(r, service.sender, "join") for r in
+            matching_room_list if r not in known_room_ids
+        ]
+        rooms_for_user_matching_user_id |= set(missing_rooms_for_user)
+
+        return rooms_for_user_matching_user_id
+
     @defer.inlineCallbacks
     def _populate_cache(self):
         """Populates the ApplicationServiceCache from the database."""
@@ -215,10 +319,12 @@ class ApplicationServiceStore(SQLBaseStore):
             try:
                 services[as_token]["namespaces"][
                     ApplicationService.NS_LIST[ns_int]].append(
-                    res["regex"]
+                    simplejson.loads(res["regex"])
                 )
             except IndexError:
                 logger.error("Bad namespace enum '%s'. %s", ns_int, res)
+            except JSONDecodeError:
+                logger.error("Bad regex object '%s'", res["regex"])
 
         # TODO get last successful txn id f.e. service
         for service in services.values():
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 3fbc090224..2deda8ac50 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -64,6 +64,9 @@ class EventFederationStore(SQLBaseStore):
             for f in front:
                 txn.execute(base_sql, (f,))
                 new_front.update([r[0] for r in txn.fetchall()])
+
+            new_front -= results
+
             front = new_front
             results.update(front)
 
@@ -378,3 +381,51 @@ class EventFederationStore(SQLBaseStore):
             event_results += new_front
 
         return self._get_events_txn(txn, event_results)
+
+    def get_missing_events(self, room_id, earliest_events, latest_events,
+                           limit, min_depth):
+        return self.runInteraction(
+            "get_missing_events",
+            self._get_missing_events,
+            room_id, earliest_events, latest_events, limit, min_depth
+        )
+
+    def _get_missing_events(self, txn, room_id, earliest_events, latest_events,
+                            limit, min_depth):
+
+        earliest_events = set(earliest_events)
+        front = set(latest_events) - earliest_events
+
+        event_results = set()
+
+        query = (
+            "SELECT prev_event_id FROM event_edges "
+            "WHERE room_id = ? AND event_id = ? AND is_state = 0 "
+            "LIMIT ?"
+        )
+
+        while front and len(event_results) < limit:
+            new_front = set()
+            for event_id in front:
+                txn.execute(
+                    query,
+                    (room_id, event_id, limit - len(event_results))
+                )
+
+                for e_id, in txn.fetchall():
+                    new_front.add(e_id)
+
+            new_front -= earliest_events
+            new_front -= event_results
+
+            front = new_front
+            event_results |= new_front
+
+        events = self._get_events_txn(txn, event_results)
+
+        events = sorted(
+            [ev for ev in events if ev.depth >= min_depth],
+            key=lambda e: e.depth,
+        )
+
+        return events[:limit]
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 58aa376c20..65ffb4627f 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -180,6 +180,14 @@ class RoomMemberStore(SQLBaseStore):
         if not membership_list:
             return defer.succeed(None)
 
+        return self.runInteraction(
+            "get_rooms_for_user_where_membership_is",
+            self._get_rooms_for_user_where_membership_is_txn,
+            user_id, membership_list
+        )
+
+    def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
+                                                    membership_list):
         where_clause = "user_id = ? AND (%s)" % (
             " OR ".join(["membership = ?" for _ in membership_list]),
         )
@@ -187,24 +195,18 @@ class RoomMemberStore(SQLBaseStore):
         args = [user_id]
         args.extend(membership_list)
 
-        def f(txn):
-            sql = (
-                "SELECT m.room_id, m.sender, m.membership"
-                " FROM room_memberships as m"
-                " INNER JOIN current_state_events as c"
-                " ON m.event_id = c.event_id"
-                " WHERE %s"
-            ) % (where_clause,)
-
-            txn.execute(sql, args)
-            return [
-                RoomsForUser(**r) for r in self.cursor_to_dict(txn)
-            ]
+        sql = (
+            "SELECT m.room_id, m.sender, m.membership"
+            " FROM room_memberships as m"
+            " INNER JOIN current_state_events as c"
+            " ON m.event_id = c.event_id"
+            " WHERE %s"
+        ) % (where_clause,)
 
-        return self.runInteraction(
-            "get_rooms_for_user_where_membership_is",
-            f
-        )
+        txn.execute(sql, args)
+        return [
+            RoomsForUser(**r) for r in self.cursor_to_dict(txn)
+        ]
 
     def get_joined_hosts_for_room(self, room_id):
         return self._simple_select_onecol(
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 3ccb6f8a61..09bc522210 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -36,6 +36,7 @@ what sort order was used:
 from twisted.internet import defer
 
 from ._base import SQLBaseStore
+from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.util.logutils import log_function
 
@@ -127,6 +128,85 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
 
 
 class StreamStore(SQLBaseStore):
+
+    @defer.inlineCallbacks
+    def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
+        # NB this lives here instead of appservice.py so we can reuse the
+        # 'private' StreamToken class in this file.
+        if limit:
+            limit = max(limit, MAX_STREAM_SIZE)
+        else:
+            limit = MAX_STREAM_SIZE
+
+        # From and to keys should be integers from ordering.
+        from_id = _StreamToken.parse_stream_token(from_key)
+        to_id = _StreamToken.parse_stream_token(to_key)
+
+        if from_key == to_key:
+            defer.returnValue(([], to_key))
+            return
+
+        # select all the events between from/to with a sensible limit
+        sql = (
+            "SELECT e.event_id, e.room_id, e.type, s.state_key, "
+            "e.stream_ordering FROM events AS e LEFT JOIN state_events as s ON "
+            "e.event_id = s.event_id "
+            "WHERE e.stream_ordering > ? AND e.stream_ordering <= ? "
+            "ORDER BY stream_ordering ASC LIMIT %(limit)d "
+        ) % {
+            "limit": limit
+        }
+
+        def f(txn):
+            # pull out all the events between the tokens
+            txn.execute(sql, (from_id.stream, to_id.stream,))
+            rows = self.cursor_to_dict(txn)
+
+            # Logic:
+            #  - We want ALL events which match the AS room_id regex
+            #  - We want ALL events which match the rooms represented by the AS
+            #    room_alias regex
+            #  - We want ALL events for rooms that AS users have joined.
+            # This is currently supported via get_app_service_rooms (which is
+            # used for the Notifier listener rooms). We can't reasonably make a
+            # SQL query for these room IDs, so we'll pull all the events between
+            # from/to and filter in python.
+            rooms_for_as = self._get_app_service_rooms_txn(txn, service)
+            room_ids_for_as = [r.room_id for r in rooms_for_as]
+
+            def app_service_interested(row):
+                if row["room_id"] in room_ids_for_as:
+                    return True
+
+                if row["type"] == EventTypes.Member:
+                    if service.is_interested_in_user(row.get("state_key")):
+                        return True
+                return False
+
+            ret = self._get_events_txn(
+                txn,
+                # apply the filter on the room id list
+                [
+                    r["event_id"] for r in rows
+                    if app_service_interested(r)
+                ],
+                get_prev_content=True
+            )
+
+            self._set_before_and_after(ret, rows)
+
+            if rows:
+                key = "s%d" % max(r["stream_ordering"] for r in rows)
+            else:
+                # Assume we didn't get anything because there was nothing to
+                # get.
+                key = to_key
+
+            return ret, key
+
+        results = yield self.runInteraction("get_appservice_room_stream", f)
+        defer.returnValue(results)
+
     @log_function
     def get_room_events_stream(self, user_id, from_key, to_key, room_id,
                                limit=0, with_feedback=False):
@@ -184,8 +264,7 @@ class StreamStore(SQLBaseStore):
             self._set_before_and_after(ret, rows)
 
             if rows:
-                key = "s%d" % max([r["stream_ordering"] for r in rows])
-
+                key = "s%d" % max(r["stream_ordering"] for r in rows)
             else:
                 # Assume we didn't get anything because there was nothing to
                 # get.