summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py45
-rw-r--r--synapse/storage/_base.py49
-rw-r--r--synapse/storage/events.py17
-rw-r--r--synapse/storage/receipts.py67
-rw-r--r--synapse/storage/stream.py151
-rw-r--r--synapse/storage/tags.py7
-rw-r--r--synapse/storage/util/id_generators.py36
7 files changed, 222 insertions, 150 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7a3f6c4662..c8cab45f77 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -46,6 +46,9 @@ from .tags import TagsStore
 from .account_data import AccountDataStore
 
 
+from util.id_generators import IdGenerator, StreamIdGenerator
+
+
 import logging
 
 
@@ -58,6 +61,22 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120*1000
 
 
+def get_datastore(hs):
+    logger.info("getting called!")
+
+    conn = hs.get_db_conn()
+    try:
+        cur = conn.cursor()
+        cur.execute("SELECT MIN(stream_ordering) FROM events",)
+        rows = cur.fetchall()
+        min_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
+        min_token = min(min_token, -1)
+
+        return DataStore(conn, hs, min_token)
+    finally:
+        conn.close()
+
+
 class DataStore(RoomMemberStore, RoomStore,
                 RegistrationStore, StreamStore, ProfileStore,
                 PresenceStore, TransactionStore,
@@ -79,18 +98,36 @@ class DataStore(RoomMemberStore, RoomStore,
                 EventPushActionsStore
                 ):
 
-    def __init__(self, hs):
-        super(DataStore, self).__init__(hs)
+    def __init__(self, db_conn, hs, min_stream_token):
         self.hs = hs
 
-        self.min_token_deferred = self._get_min_token()
-        self.min_token = None
+        self.min_stream_token = min_stream_token
 
         self.client_ip_last_seen = Cache(
             name="client_ip_last_seen",
             keylen=4,
         )
 
+        self._stream_id_gen = StreamIdGenerator(
+            db_conn, "events", "stream_ordering"
+        )
+        self._receipts_id_gen = StreamIdGenerator(
+            db_conn, "receipts_linearized", "stream_id"
+        )
+        self._account_data_id_gen = StreamIdGenerator(
+            db_conn, "account_data_max_stream_id", "stream_id"
+        )
+
+        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
+        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
+        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
+        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
+        self._pushers_id_gen = IdGenerator("pushers", "id", self)
+        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
+        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+
+        super(DataStore, self).__init__(hs)
+
     @defer.inlineCallbacks
     def insert_client_ip(self, user, access_token, ip, user_agent):
         now = int(self._clock.time_msec())
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 90d7aee94a..5e77320540 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -15,13 +15,11 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.caches.descriptors import Cache
 import synapse.metrics
 
-from util.id_generators import IdGenerator, StreamIdGenerator
 
 from twisted.internet import defer
 
@@ -175,16 +173,6 @@ class SQLBaseStore(object):
 
         self.database_engine = hs.database_engine
 
-        self._stream_id_gen = StreamIdGenerator("events", "stream_ordering")
-        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
-        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
-        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
-        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
-        self._pushers_id_gen = IdGenerator("pushers", "id", self)
-        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
-        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
-        self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id")
-
     def start_profiling(self):
         self._previous_loop_ts = self._clock.time_msec()
 
@@ -345,7 +333,8 @@ class SQLBaseStore(object):
 
         defer.returnValue(result)
 
-    def cursor_to_dict(self, cursor):
+    @staticmethod
+    def cursor_to_dict(cursor):
         """Converts a SQL cursor into an list of dicts.
 
         Args:
@@ -402,8 +391,8 @@ class SQLBaseStore(object):
             if not or_ignore:
                 raise
 
-    @log_function
-    def _simple_insert_txn(self, txn, table, values):
+    @staticmethod
+    def _simple_insert_txn(txn, table, values):
         keys, vals = zip(*values.items())
 
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -414,7 +403,8 @@ class SQLBaseStore(object):
 
         txn.execute(sql, vals)
 
-    def _simple_insert_many_txn(self, txn, table, values):
+    @staticmethod
+    def _simple_insert_many_txn(txn, table, values):
         if not values:
             return
 
@@ -537,9 +527,10 @@ class SQLBaseStore(object):
             table, keyvalues, retcol, allow_none=allow_none,
         )
 
-    def _simple_select_one_onecol_txn(self, txn, table, keyvalues, retcol,
+    @classmethod
+    def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
                                       allow_none=False):
-        ret = self._simple_select_onecol_txn(
+        ret = cls._simple_select_onecol_txn(
             txn,
             table=table,
             keyvalues=keyvalues,
@@ -554,7 +545,8 @@ class SQLBaseStore(object):
             else:
                 raise StoreError(404, "No row found")
 
-    def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
+    @staticmethod
+    def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
         sql = (
             "SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
         ) % {
@@ -603,7 +595,8 @@ class SQLBaseStore(object):
             table, keyvalues, retcols
         )
 
-    def _simple_select_list_txn(self, txn, table, keyvalues, retcols):
+    @classmethod
+    def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -627,7 +620,7 @@ class SQLBaseStore(object):
             )
             txn.execute(sql)
 
-        return self.cursor_to_dict(txn)
+        return cls.cursor_to_dict(txn)
 
     @defer.inlineCallbacks
     def _simple_select_many_batch(self, table, column, iterable, retcols,
@@ -662,7 +655,8 @@ class SQLBaseStore(object):
 
         defer.returnValue(results)
 
-    def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
+    @classmethod
+    def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
@@ -699,7 +693,7 @@ class SQLBaseStore(object):
             )
 
         txn.execute(sql, values)
-        return self.cursor_to_dict(txn)
+        return cls.cursor_to_dict(txn)
 
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            desc="_simple_update_one"):
@@ -726,7 +720,8 @@ class SQLBaseStore(object):
             table, keyvalues, updatevalues,
         )
 
-    def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
+    @staticmethod
+    def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
         update_sql = "UPDATE %s SET %s WHERE %s" % (
             table,
             ", ".join("%s = ?" % (k,) for k in updatevalues),
@@ -743,7 +738,8 @@ class SQLBaseStore(object):
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched")
 
-    def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
+    @staticmethod
+    def _simple_select_one_txn(txn, table, keyvalues, retcols,
                                allow_none=False):
         select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
@@ -784,7 +780,8 @@ class SQLBaseStore(object):
                 raise StoreError(500, "more than one row matched")
         return self.runInteraction(desc, func)
 
-    def _simple_delete_txn(self, txn, table, keyvalues):
+    @staticmethod
+    def _simple_delete_txn(txn, table, keyvalues):
         sql = "DELETE FROM %s WHERE %s" % (
             table,
             " AND ".join("%s = ?" % (k, ) for k in keyvalues)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index ba368a3eca..80187722ea 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -66,11 +66,9 @@ class EventsStore(SQLBaseStore):
             return
 
         if backfilled:
-            if not self.min_token_deferred.called:
-                yield self.min_token_deferred
-            start = self.min_token - 1
-            self.min_token -= len(events_and_contexts) + 1
-            stream_orderings = range(start, self.min_token, -1)
+            start = self.min_stream_token - 1
+            self.min_stream_token -= len(events_and_contexts) + 1
+            stream_orderings = range(start, self.min_stream_token, -1)
 
             @contextmanager
             def stream_ordering_manager():
@@ -107,10 +105,8 @@ class EventsStore(SQLBaseStore):
                       is_new_state=True, current_state=None):
         stream_ordering = None
         if backfilled:
-            if not self.min_token_deferred.called:
-                yield self.min_token_deferred
-            self.min_token -= 1
-            stream_ordering = self.min_token
+            self.min_stream_token -= 1
+            stream_ordering = self.min_stream_token
 
         if stream_ordering is None:
             stream_ordering_manager = yield self._stream_id_gen.get_next(self)
@@ -132,6 +128,9 @@ class EventsStore(SQLBaseStore):
                     is_new_state=is_new_state,
                     current_state=current_state,
                 )
+                self._events_stream_cache.room_has_changed(
+                    None, event.room_id, stream_ordering
+                )
         except _RollbackButIsFineException:
             pass
 
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index c4232bdc65..7118368d97 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -15,11 +15,10 @@
 
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
-from synapse.util.caches import cache_counter, caches_by_name
+from synapse.util.caches.room_change_cache import RoomStreamChangeCache
 
 from twisted.internet import defer
 
-from blist import sorteddict
 import logging
 import ujson as json
 
@@ -31,7 +30,9 @@ class ReceiptsStore(SQLBaseStore):
     def __init__(self, hs):
         super(ReceiptsStore, self).__init__(hs)
 
-        self._receipts_stream_cache = _RoomStreamChangeCache()
+        self._receipts_stream_cache = RoomStreamChangeCache(
+            "ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token(None)
+        )
 
     @cached(num_args=2)
     def get_receipts_for_room(self, room_id, receipt_type):
@@ -368,63 +369,3 @@ class ReceiptsStore(SQLBaseStore):
                 "data": json.dumps(data),
             }
         )
-
-
-class _RoomStreamChangeCache(object):
-    """Keeps track of the stream_id of the latest change in rooms.
-
-    Given a list of rooms and stream key, it will give a subset of rooms that
-    may have changed since that key. If the key is too old then the cache
-    will simply return all rooms.
-    """
-    def __init__(self, size_of_cache=10000):
-        self._size_of_cache = size_of_cache
-        self._room_to_key = {}
-        self._cache = sorteddict()
-        self._earliest_key = None
-        self.name = "ReceiptsRoomChangeCache"
-        caches_by_name[self.name] = self._cache
-
-    @defer.inlineCallbacks
-    def get_rooms_changed(self, store, room_ids, key):
-        """Returns subset of room ids that have had new receipts since the
-        given key. If the key is too old it will just return the given list.
-        """
-        if key > (yield self._get_earliest_key(store)):
-            keys = self._cache.keys()
-            i = keys.bisect_right(key)
-
-            result = set(
-                self._cache[k] for k in keys[i:]
-            ).intersection(room_ids)
-
-            cache_counter.inc_hits(self.name)
-        else:
-            result = room_ids
-            cache_counter.inc_misses(self.name)
-
-        defer.returnValue(result)
-
-    @defer.inlineCallbacks
-    def room_has_changed(self, store, room_id, key):
-        """Informs the cache that the room has been changed at the given key.
-        """
-        if key > (yield self._get_earliest_key(store)):
-            old_key = self._room_to_key.get(room_id, None)
-            if old_key:
-                key = max(key, old_key)
-                self._cache.pop(old_key, None)
-            self._cache[key] = room_id
-
-            while len(self._cache) > self._size_of_cache:
-                k, r = self._cache.popitem()
-                self._earliest_key = max(k, self._earliest_key)
-                self._room_to_key.pop(r, None)
-
-    @defer.inlineCallbacks
-    def _get_earliest_key(self, store):
-        if self._earliest_key is None:
-            self._earliest_key = yield store.get_max_receipt_stream_id()
-            self._earliest_key = int(self._earliest_key)
-
-        defer.returnValue(self._earliest_key)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 02b1913e26..0b22251790 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -37,6 +37,7 @@ from twisted.internet import defer
 
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks
+from synapse.util.caches.room_change_cache import RoomStreamChangeCache
 from synapse.api.constants import EventTypes
 from synapse.types import RoomStreamToken
 from synapse.util.logutils import log_function
@@ -77,6 +78,12 @@ def upper_bound(token):
 
 
 class StreamStore(SQLBaseStore):
+    def __init__(self, hs):
+        super(StreamStore, self).__init__(hs)
+
+        self._events_stream_cache = RoomStreamChangeCache(
+            "EventsRoomStreamChangeCache", self._stream_id_gen.get_max_token(None)
+        )
 
     @defer.inlineCallbacks
     def get_appservice_room_stream(self, service, from_key, to_key, limit=0):
@@ -157,6 +164,134 @@ class StreamStore(SQLBaseStore):
         results = yield self.runInteraction("get_appservice_room_stream", f)
         defer.returnValue(results)
 
+    @defer.inlineCallbacks
+    def get_room_events_stream_for_rooms(self, room_ids, from_key, to_key, limit=0):
+        from_id = RoomStreamToken.parse_stream_token(from_key).stream
+
+        room_ids = yield self._events_stream_cache.get_rooms_changed(
+            self, room_ids, from_id
+        )
+
+        if not room_ids:
+            defer.returnValue({})
+
+        results = {}
+        room_ids = list(room_ids)
+        for rm_ids in (room_ids[i:i+20] for i in xrange(0, len(room_ids), 20)):
+            res = yield defer.gatherResults([
+                self.get_room_events_stream_for_room(
+                    room_id, from_key, to_key, limit
+                ).addCallback(lambda r, rm: (rm, r), room_id)
+                for room_id in room_ids
+            ])
+            results.update(dict(res))
+
+        defer.returnValue(results)
+
+    @defer.inlineCallbacks
+    def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0):
+        if from_key is not None:
+            from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        else:
+            from_id = None
+        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+        if from_key == to_key:
+            defer.returnValue(([], from_key))
+
+        has_changed = yield self._events_stream_cache.get_room_has_changed(
+            room_id, from_id
+        )
+
+        if not has_changed:
+            defer.returnValue(([], from_key))
+
+        def f(txn):
+            if from_id is not None:
+                sql = (
+                    "SELECT event_id, stream_ordering FROM events WHERE"
+                    " room_id = ?"
+                    " AND not outlier"
+                    " AND stream_ordering > ? AND stream_ordering <= ?"
+                    " ORDER BY stream_ordering DESC LIMIT ?"
+                )
+                txn.execute(sql, (room_id, from_id, to_id, limit))
+            else:
+                sql = (
+                    "SELECT event_id, stream_ordering FROM events WHERE"
+                    " room_id = ?"
+                    " AND not outlier"
+                    " AND stream_ordering <= ?"
+                    " ORDER BY stream_ordering DESC LIMIT ?"
+                )
+                txn.execute(sql, (room_id, to_id, limit))
+
+            rows = self.cursor_to_dict(txn)
+
+            ret = self._get_events_txn(
+                txn,
+                [r["event_id"] for r in rows],
+                get_prev_content=True
+            )
+
+            ret.reverse()
+
+            self._set_before_and_after(ret, rows)
+
+            if rows:
+                key = "s%d" % min(r["stream_ordering"] for r in rows)
+            else:
+                # Assume we didn't get anything because there was nothing to
+                # get.
+                key = from_key
+
+            return ret, key
+        res = yield self.runInteraction("get_room_events_stream_for_room", f)
+        defer.returnValue(res)
+
+    def get_room_changes_for_user(self, user_id, from_key, to_key):
+        if from_key is not None:
+            from_id = RoomStreamToken.parse_stream_token(from_key).stream
+        else:
+            from_id = None
+        to_id = RoomStreamToken.parse_stream_token(to_key).stream
+
+        if from_key == to_key:
+            return defer.succeed([])
+
+        def f(txn):
+            if from_id is not None:
+                sql = (
+                    "SELECT m.event_id, stream_ordering FROM events AS e,"
+                    " room_memberships AS m"
+                    " WHERE e.event_id = m.event_id"
+                    " AND m.user_id = ?"
+                    " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
+                    " ORDER BY e.stream_ordering ASC"
+                )
+                txn.execute(sql, (user_id, from_id, to_id,))
+            else:
+                sql = (
+                    "SELECT m.event_id, stream_ordering FROM events AS e,"
+                    " room_memberships AS m"
+                    " WHERE e.event_id = m.event_id"
+                    " AND m.user_id = ?"
+                    " AND stream_ordering <= ?"
+                    " ORDER BY stream_ordering ASC"
+                )
+                txn.execute(sql, (user_id, to_id,))
+            rows = self.cursor_to_dict(txn)
+
+            ret = self._get_events_txn(
+                txn,
+                [r["event_id"] for r in rows],
+                get_prev_content=True
+            )
+
+            return ret
+
+        return self.runInteraction("get_room_changes_for_user", f)
+
     @log_function
     def get_room_events_stream(
         self,
@@ -174,7 +309,8 @@ class StreamStore(SQLBaseStore):
                 "SELECT c.room_id FROM history_visibility AS h"
                 " INNER JOIN current_state_events AS c"
                 " ON h.event_id = c.event_id"
-                " WHERE c.room_id IN (%s) AND h.history_visibility = 'world_readable'" % (
+                " WHERE c.room_id IN (%s)"
+                " AND h.history_visibility = 'world_readable'" % (
                     ",".join(map(lambda _: "?", room_ids))
                 )
             )
@@ -444,19 +580,6 @@ class StreamStore(SQLBaseStore):
         rows = txn.fetchall()
         return rows[0][0] if rows else 0
 
-    @defer.inlineCallbacks
-    def _get_min_token(self):
-        row = yield self._execute(
-            "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
-        )
-
-        self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
-        self.min_token = min(self.min_token, -1)
-
-        logger.debug("min_token is: %s", self.min_token)
-
-        defer.returnValue(self.min_token)
-
     @staticmethod
     def _set_before_and_after(events, rows):
         for event, row in zip(events, rows):
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index ed9c91e5ea..4c39e07cbd 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -16,7 +16,6 @@
 from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cached
 from twisted.internet import defer
-from .util.id_generators import StreamIdGenerator
 
 import ujson as json
 import logging
@@ -25,12 +24,6 @@ logger = logging.getLogger(__name__)
 
 
 class TagsStore(SQLBaseStore):
-    def __init__(self, hs):
-        super(TagsStore, self).__init__(hs)
-
-        self._account_data_id_gen = StreamIdGenerator(
-            "account_data_max_stream_id", "stream_id"
-        )
 
     def get_max_account_data_stream_id(self):
         """Get the current max stream id for the private user data stream
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f58bf7fd2c..5c522f4ab9 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -72,28 +72,24 @@ class StreamIdGenerator(object):
         with stream_id_gen.get_next_txn(txn) as stream_id:
             # ... persist event ...
     """
-    def __init__(self, table, column):
+    def __init__(self, db_conn, table, column):
         self.table = table
         self.column = column
 
         self._lock = threading.Lock()
 
-        self._current_max = None
+        cur = db_conn.cursor()
+        self._current_max = self._get_or_compute_current_max(cur)
+        cur.close()
+
         self._unfinished_ids = deque()
 
-    @defer.inlineCallbacks
     def get_next(self, store):
         """
         Usage:
             with yield stream_id_gen.get_next as stream_id:
                 # ... persist event ...
         """
-        if not self._current_max:
-            yield store.runInteraction(
-                "_compute_current_max",
-                self._get_or_compute_current_max,
-            )
-
         with self._lock:
             self._current_max += 1
             next_id = self._current_max
@@ -108,21 +104,14 @@ class StreamIdGenerator(object):
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
 
-        defer.returnValue(manager())
+        return manager()
 
-    @defer.inlineCallbacks
     def get_next_mult(self, store, n):
         """
         Usage:
             with yield stream_id_gen.get_next(store, n) as stream_ids:
                 # ... persist events ...
         """
-        if not self._current_max:
-            yield store.runInteraction(
-                "_compute_current_max",
-                self._get_or_compute_current_max,
-            )
-
         with self._lock:
             next_ids = range(self._current_max + 1, self._current_max + n + 1)
             self._current_max += n
@@ -139,24 +128,17 @@ class StreamIdGenerator(object):
                     for next_id in next_ids:
                         self._unfinished_ids.remove(next_id)
 
-        defer.returnValue(manager())
+        return manager()
 
-    @defer.inlineCallbacks
     def get_max_token(self, store):
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
         """
-        if not self._current_max:
-            yield store.runInteraction(
-                "_compute_current_max",
-                self._get_or_compute_current_max,
-            )
-
         with self._lock:
             if self._unfinished_ids:
-                defer.returnValue(self._unfinished_ids[0] - 1)
+                return self._unfinished_ids[0] - 1
 
-            defer.returnValue(self._current_max)
+            return self._current_max
 
     def _get_or_compute_current_max(self, txn):
         with self._lock: