summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/__init__.py76
-rw-r--r--synapse/storage/_base.py159
-rw-r--r--synapse/storage/end_to_end_keys.py2
-rw-r--r--synapse/storage/events.py37
-rw-r--r--synapse/storage/roommember.py25
-rw-r--r--synapse/storage/stream.py14
6 files changed, 278 insertions, 35 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index b9968debe5..d604e7668f 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -297,6 +297,82 @@ class DataStore(RoomMemberStore, RoomStore,
             desc="get_user_ip_and_agents",
         )
 
+    def get_users(self):
+        """Function to reterive a list of users in users table.
+
+        Args:
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self._simple_select_list(
+            table="users",
+            keyvalues={},
+            retcols=[
+                "name",
+                "password_hash",
+                "is_guest",
+                "admin"
+            ],
+            desc="get_users",
+        )
+
+    def get_users_paginate(self, order, start, limit):
+        """Function to reterive a paginated list of users from
+        users list. This will return a json object, which contains
+        list of users and the total number of users in users table.
+
+        Args:
+            order (str): column name to order the select by this column
+            start (int): start number to begin the query from
+            limit (int): number of rows to reterive
+        Returns:
+            defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+        """
+        is_guest = 0
+        i_start = (int)(start)
+        i_limit = (int)(limit)
+        return self.get_user_list_paginate(
+            table="users",
+            keyvalues={
+                "is_guest": is_guest
+            },
+            pagevalues=[
+                order,
+                i_limit,
+                i_start
+            ],
+            retcols=[
+                "name",
+                "password_hash",
+                "is_guest",
+                "admin"
+            ],
+            desc="get_users_paginate",
+        )
+
+    def search_users(self, term):
+        """Function to search users list for one or more users with
+        the matched term.
+
+        Args:
+            term (str): search term
+            col (str): column to query term should be matched to
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self._simple_search_list(
+            table="users",
+            term=term,
+            col="name",
+            retcols=[
+                "name",
+                "password_hash",
+                "is_guest",
+                "admin"
+            ],
+            desc="search_users",
+        )
+
 
 def are_all_users_on_domain(txn, database_engine, domain):
     sql = database_engine.convert_param_style(
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 05374682fd..b0dc391190 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -934,6 +934,165 @@ class SQLBaseStore(object):
         else:
             return 0
 
+    def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
+                                     desc="_simple_select_list_paginate"):
+        """Executes a SELECT query on the named table with start and limit,
+        of row numbers, which may return zero or number of rows from start to limit,
+        returning the result as a list of dicts.
+
+        Args:
+            table (str): the table name
+            keyvalues (dict[str, Any] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            retcols (iterable[str]): the names of the columns to return
+            order (str): order the select by this column
+            start (int): start number to begin the query from
+            limit (int): number of rows to reterive
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+        """
+        return self.runInteraction(
+            desc,
+            self._simple_select_list_paginate_txn,
+            table, keyvalues, pagevalues, retcols
+        )
+
+    @classmethod
+    def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
+        """Executes a SELECT query on the named table with start and limit,
+        of row numbers, which may return zero or number of rows from start to limit,
+        returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            keyvalues (dict[str, T] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            pagevalues ([]):
+                order (str): order the select by this column
+                start (int): start number to begin the query from
+                limit (int): number of rows to reterive
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]]
+
+        """
+        if keyvalues:
+            sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
+                ", ".join(retcols),
+                table,
+                " AND ".join("%s = ?" % (k,) for k in keyvalues),
+                " ? ASC LIMIT ? OFFSET ?"
+            )
+            txn.execute(sql, keyvalues.values() + pagevalues)
+        else:
+            sql = "SELECT %s FROM %s ORDER BY %s" % (
+                ", ".join(retcols),
+                table,
+                " ? ASC LIMIT ? OFFSET ?"
+            )
+            txn.execute(sql, pagevalues)
+
+        return cls.cursor_to_dict(txn)
+
+    @defer.inlineCallbacks
+    def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
+                               desc="get_user_list_paginate"):
+        """Get a list of users from start row to a limit number of rows. This will
+        return a json object with users and total number of users in users list.
+
+        Args:
+            table (str): the table name
+            keyvalues (dict[str, Any] | None):
+                column names and values to select the rows with, or None to not
+                apply a WHERE clause.
+            pagevalues ([]):
+                order (str): order the select by this column
+                start (int): start number to begin the query from
+                limit (int): number of rows to reterive
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+        """
+        users = yield self.runInteraction(
+            desc,
+            self._simple_select_list_paginate_txn,
+            table, keyvalues, pagevalues, retcols
+        )
+        count = yield self.runInteraction(
+            desc,
+            self.get_user_count_txn
+        )
+        retval = {
+            "users": users,
+            "total": count
+        }
+        defer.returnValue(retval)
+
+    def get_user_count_txn(self, txn):
+        """Get a total number of registerd users in the users list.
+
+        Args:
+            txn : Transaction object
+        Returns:
+            defer.Deferred: resolves to int
+        """
+        sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
+        txn.execute(sql_count)
+        count = txn.fetchone()[0]
+        defer.returnValue(count)
+
+    def _simple_search_list(self, table, term, col, retcols,
+                            desc="_simple_search_list"):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            table (str): the table name
+            term (str | None):
+                term for searching the table matched to a column.
+            col (str): column to query term should be matched to
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]] or None
+        """
+
+        return self.runInteraction(
+            desc,
+            self._simple_search_list_txn,
+            table, term, col, retcols
+        )
+
+    @classmethod
+    def _simple_search_list_txn(cls, txn, table, term, col, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Args:
+            txn : Transaction object
+            table (str): the table name
+            term (str | None):
+                term for searching the table matched to a column.
+            col (str): column to query term should be matched to
+            retcols (iterable[str]): the names of the columns to return
+        Returns:
+            defer.Deferred: resolves to list[dict[str, Any]] or None
+        """
+        if term:
+            sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
+                ", ".join(retcols),
+                table,
+                col
+            )
+            termvalues = ["%%" + term + "%%"]
+            txn.execute(sql, termvalues)
+        else:
+            return 0
+
+        return cls.cursor_to_dict(txn)
+
 
 class _RollbackButIsFineException(Exception):
     """ This exception is used to rollback a transaction without implying
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2040e022fa..b9f1365f92 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -93,7 +93,7 @@ class EndToEndKeyStore(SQLBaseStore):
             query_clause = "user_id = ?"
             query_params.append(user_id)
 
-            if device_id:
+            if device_id is not None:
                 query_clause += " AND device_id = ?"
                 query_params.append(device_id)
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 6685b9da1c..c88f689d3a 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -28,6 +28,7 @@ from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.state import resolve_events
+from synapse.util.caches.descriptors import cached
 
 from canonicaljson import encode_canonical_json
 from collections import deque, namedtuple, OrderedDict
@@ -301,7 +302,7 @@ class EventsStore(SQLBaseStore):
                                 room_id
                             )
                             new_latest_event_ids = yield self._calculate_new_extremeties(
-                                room_id, [ev for ev, _ in ev_ctx_rm]
+                                room_id, ev_ctx_rm, latest_event_ids
                             )
 
                             if new_latest_event_ids == set(latest_event_ids):
@@ -328,27 +329,24 @@ class EventsStore(SQLBaseStore):
                 persist_event_counter.inc_by(len(chunk))
 
     @defer.inlineCallbacks
-    def _calculate_new_extremeties(self, room_id, events):
+    def _calculate_new_extremeties(self, room_id, event_contexts, latest_event_ids):
         """Calculates the new forward extremeties for a room given events to
         persist.
 
         Assumes that we are only persisting events for one room at a time.
         """
-        latest_event_ids = yield self.get_latest_event_ids_in_room(
-            room_id
-        )
         new_latest_event_ids = set(latest_event_ids)
         # First, add all the new events to the list
         new_latest_event_ids.update(
-            event.event_id for event in events
-            if not event.internal_metadata.is_outlier()
+            event.event_id for event, ctx in event_contexts
+            if not event.internal_metadata.is_outlier() and not ctx.rejected
         )
         # Now remove all events that are referenced by the to-be-added events
         new_latest_event_ids.difference_update(
             e_id
-            for event in events
+            for event, ctx in event_contexts
             for e_id, _ in event.prev_events
-            if not event.internal_metadata.is_outlier()
+            if not event.internal_metadata.is_outlier() and not ctx.rejected
         )
 
         # And finally remove any events that are referenced by previously added
@@ -572,14 +570,6 @@ class EventsStore(SQLBaseStore):
                     txn, self.get_users_in_room, (room_id,)
                 )
 
-                # Add an entry to the current_state_resets table to record the point
-                # where we clobbered the current state
-                self._simple_insert_txn(
-                    txn,
-                    table="current_state_resets",
-                    values={"event_stream_ordering": max_stream_order}
-                )
-
         for room_id, new_extrem in new_forward_extremeties.items():
             self._simple_delete_txn(
                 txn,
@@ -1579,6 +1569,7 @@ class EventsStore(SQLBaseStore):
         """The current minimum token that backfilled events have reached"""
         return -self._backfill_id_gen.get_current_token()
 
+    @cached(num_args=5, max_entries=10)
     def get_all_new_events(self, last_backfill_id, last_forward_id,
                            current_backfill_id, current_forward_id, limit):
         """Get all the new events that have arrived at the server either as
@@ -1611,15 +1602,6 @@ class EventsStore(SQLBaseStore):
                     upper_bound = current_forward_id
 
                 sql = (
-                    "SELECT event_stream_ordering FROM current_state_resets"
-                    " WHERE ? < event_stream_ordering"
-                    " AND event_stream_ordering <= ?"
-                    " ORDER BY event_stream_ordering ASC"
-                )
-                txn.execute(sql, (last_forward_id, upper_bound))
-                state_resets = txn.fetchall()
-
-                sql = (
                     "SELECT event_stream_ordering, event_id, state_group"
                     " FROM ex_outlier_stream"
                     " WHERE ? > event_stream_ordering"
@@ -1630,7 +1612,6 @@ class EventsStore(SQLBaseStore):
                 forward_ex_outliers = txn.fetchall()
             else:
                 new_forward_events = []
-                state_resets = []
                 forward_ex_outliers = []
 
             sql = (
@@ -1670,7 +1651,6 @@ class EventsStore(SQLBaseStore):
             return AllNewEventsResult(
                 new_forward_events, new_backfill_events,
                 forward_ex_outliers, backward_ex_outliers,
-                state_resets,
             )
         return self.runInteraction("get_all_new_events", get_all_new_events_txn)
 
@@ -1896,5 +1876,4 @@ class EventsStore(SQLBaseStore):
 AllNewEventsResult = namedtuple("AllNewEventsResult", [
     "new_forward_events", "new_backfill_events",
     "forward_ex_outliers", "backward_ex_outliers",
-    "state_resets"
 ])
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 10f7c7a4bc..545d3d3a99 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -66,8 +66,6 @@ class RoomMemberStore(SQLBaseStore):
         )
 
         for event in events:
-            txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
-            txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
             txn.call_after(
                 self._membership_stream_cache.entity_has_changed,
                 event.state_key, event.internal_metadata.stream_ordering
@@ -131,7 +129,7 @@ class RoomMemberStore(SQLBaseStore):
         with self._stream_id_gen.get_next() as stream_ordering:
             yield self.runInteraction("locally_reject_invite", f, stream_ordering)
 
-    @cached(max_entries=100000, iterable=True)
+    @cached(max_entries=500000, iterable=True)
     def get_users_in_room(self, room_id):
         def f(txn):
 
@@ -266,7 +264,7 @@ class RoomMemberStore(SQLBaseStore):
             " ON m.event_id = c.event_id "
             " AND m.room_id = c.room_id "
             " AND m.user_id = c.state_key"
-            " WHERE %(where)s"
+            " WHERE c.type = 'm.room.member' AND %(where)s"
         ) % {
             "where": where_clause,
         }
@@ -276,12 +274,29 @@ class RoomMemberStore(SQLBaseStore):
 
         return rows
 
-    @cached(max_entries=5000)
+    @cached(max_entries=500000, iterable=True)
     def get_rooms_for_user(self, user_id):
         return self.get_rooms_for_user_where_membership_is(
             user_id, membership_list=[Membership.JOIN],
         )
 
+    @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
+    def get_users_who_share_room_with_user(self, user_id, cache_context):
+        """Returns the set of users who share a room with `user_id`
+        """
+        rooms = yield self.get_rooms_for_user(
+            user_id, on_invalidate=cache_context.invalidate,
+        )
+
+        user_who_share_room = set()
+        for room in rooms:
+            user_ids = yield self.get_users_in_room(
+                room.room_id, on_invalidate=cache_context.invalidate,
+            )
+            user_who_share_room.update(user_ids)
+
+        defer.returnValue(user_who_share_room)
+
     def forget(self, user_id, room_id):
         """Indicate that user_id wishes to discard history for room_id."""
         def f(txn):
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 2dc24951c4..200d124632 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -244,6 +244,20 @@ class StreamStore(SQLBaseStore):
 
         defer.returnValue(results)
 
+    def get_rooms_that_changed(self, room_ids, from_key):
+        """Given a list of rooms and a token, return rooms where there may have
+        been changes.
+
+        Args:
+            room_ids (list)
+            from_key (str): The room_key portion of a StreamToken
+        """
+        from_key = RoomStreamToken.parse_stream_token(from_key).stream
+        return set(
+            room_id for room_id in room_ids
+            if self._events_stream_cache.has_entity_changed(room_id, from_key)
+        )
+
     @defer.inlineCallbacks
     def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
                                         order='DESC'):