summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/register.py2
-rw-r--r--synapse/push/__init__.py30
-rw-r--r--synapse/storage/_base.py72
-rw-r--r--synapse/storage/presence.py35
-rw-r--r--synapse/storage/push_rule.py63
-rw-r--r--synapse/storage/roommember.py1
-rw-r--r--synapse/storage/schema/delta/28/event_push_actions.sql1
7 files changed, 126 insertions, 78 deletions
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 1e99c1303c..c11b98d0b7 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -52,7 +52,7 @@ class RegistrationHandler(BaseHandler):
         if urllib.quote(localpart.encode('utf-8')) != localpart:
             raise SynapseError(
                 400,
-                "User ID can only contain characters a-z, 0-9, or '-./'",
+                "User ID can only contain characters a-z, 0-9, or '_-./'",
                 Codes.INVALID_USERNAME
             )
 
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index e6a28bd8c0..9bc0b356f4 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -17,7 +17,6 @@ from twisted.internet import defer
 
 from synapse.streams.config import PaginationConfig
 from synapse.types import StreamToken
-from synapse.api.constants import Membership
 
 import synapse.util.async
 import push_rule_evaluator as push_rule_evaluator
@@ -296,31 +295,28 @@ class Pusher(object):
 
     @defer.inlineCallbacks
     def _get_badge_count(self):
-        room_list = yield self.store.get_rooms_for_user_where_membership_is(
-            user_id=self.user_id,
-            membership_list=(Membership.INVITE, Membership.JOIN)
-        )
+        invites, joins = yield defer.gatherResults([
+            self.store.get_invites_for_user(self.user_id),
+            self.store.get_rooms_for_user(self.user_id),
+        ], consumeErrors=True)
 
         my_receipts_by_room = yield self.store.get_receipts_for_user(
             self.user_id,
             "m.read",
         )
 
-        badge = 0
+        badge = len(invites)
 
-        for r in room_list:
-            if r.membership == Membership.INVITE:
-                badge += 1
-            else:
-                if r.room_id in my_receipts_by_room:
-                    last_unread_event_id = my_receipts_by_room[r.room_id]
+        for r in joins:
+            if r.room_id in my_receipts_by_room:
+                last_unread_event_id = my_receipts_by_room[r.room_id]
 
-                    notifs = yield (
-                        self.store.get_unread_event_push_actions_by_room_for_user(
-                            r.room_id, self.user_id, last_unread_event_id
-                        )
+                notifs = yield (
+                    self.store.get_unread_event_push_actions_by_room_for_user(
+                        r.room_id, self.user_id, last_unread_event_id
                     )
-                    badge += len(notifs)
+                )
+                badge += len(notifs)
         defer.returnValue(badge)
 
 
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 183a752387..90d7aee94a 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -629,6 +629,78 @@ class SQLBaseStore(object):
 
         return self.cursor_to_dict(txn)
 
+    @defer.inlineCallbacks
+    def _simple_select_many_batch(self, table, column, iterable, retcols,
+                                  keyvalues={}, desc="_simple_select_many_batch",
+                                  batch_size=100):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by if value of `column` is in `iterable`.
+
+        Args:
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        results = []
+
+        if not iterable:
+            defer.returnValue(results)
+
+        chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)]
+        for chunk in chunks:
+            rows = yield self.runInteraction(
+                desc,
+                self._simple_select_many_txn,
+                table, column, chunk, keyvalues, retcols
+            )
+
+            results.extend(rows)
+
+        defer.returnValue(results)
+
+    def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
+        """Executes a SELECT query on the named table, which may return zero or
+        more rows, returning the result as a list of dicts.
+
+        Filters rows by if value of `column` is in `iterable`.
+
+        Args:
+            txn : Transaction object
+            table : string giving the table name
+            column : column name to test for inclusion against `iterable`
+            iterable : list
+            keyvalues : dict of column names and values to select the rows with
+            retcols : list of strings giving the names of the columns to return
+        """
+        if not iterable:
+            return []
+
+        sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+
+        clauses = []
+        values = []
+        clauses.append(
+            "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
+        )
+        values.extend(iterable)
+
+        for key, value in keyvalues.items():
+            clauses.append("%s = ?" % (key,))
+            values.append(value)
+
+        if clauses:
+            sql = "%s WHERE %s" % (
+                sql,
+                " AND ".join(clauses),
+            )
+
+        txn.execute(sql, values)
+        return self.cursor_to_dict(txn)
+
     def _simple_update_one(self, table, keyvalues, updatevalues,
                            desc="_simple_update_one"):
         """Executes an UPDATE query on the named table, setting new values for
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 1095d52ace..9b3aecaf8c 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -48,24 +48,25 @@ class PresenceStore(SQLBaseStore):
             desc="get_presence_state",
         )
 
-    @cachedList(get_presence_state.cache, list_name="user_localparts")
+    @cachedList(get_presence_state.cache, list_name="user_localparts",
+                inlineCallbacks=True)
     def get_presence_states(self, user_localparts):
-        def f(txn):
-            results = {}
-            for user_localpart in user_localparts:
-                res = self._simple_select_one_txn(
-                    txn,
-                    table="presence",
-                    keyvalues={"user_id": user_localpart},
-                    retcols=["state", "status_msg", "mtime"],
-                    allow_none=True,
-                )
-                if res:
-                    results[user_localpart] = res
-
-            return results
-
-        return self.runInteraction("get_presence_states", f)
+        rows = yield self._simple_select_many_batch(
+            table="presence",
+            column="user_id",
+            iterable=user_localparts,
+            retcols=("user_id", "state", "status_msg", "mtime",),
+            desc="get_presence_states",
+        )
+
+        defer.returnValue({
+            row["user_id"]: {
+                "state": row["state"],
+                "status_msg": row["status_msg"],
+                "mtime": row["mtime"],
+            }
+            for row in rows
+        })
 
     def set_presence_state(self, user_localpart, new_state):
         res = self._simple_update_one(
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 35ec7e8cef..1f51c90ee5 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -65,32 +65,20 @@ class PushRuleStore(SQLBaseStore):
         if not user_ids:
             defer.returnValue({})
 
-        batch_size = 100
-
-        def f(txn, user_ids_to_fetch):
-            sql = (
-                "SELECT pr.*"
-                " FROM push_rules AS pr"
-                " LEFT JOIN push_rules_enable AS pre"
-                " ON pr.user_name = pre.user_name AND pr.rule_id = pre.rule_id"
-                " WHERE pr.user_name"
-                " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
-                " AND (pre.enabled IS NULL OR pre.enabled = 1)"
-                " ORDER BY pr.user_name, pr.priority_class DESC, pr.priority DESC"
-            )
-            txn.execute(sql, user_ids_to_fetch)
-            return self.cursor_to_dict(txn)
-
         results = {}
 
-        chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
-        for batch_user_ids in chunks:
-            rows = yield self.runInteraction(
-                "bulk_get_push_rules", f, batch_user_ids
-            )
+        rows = yield self._simple_select_many_batch(
+            table="push_rules",
+            column="user_name",
+            iterable=user_ids,
+            retcols=("*",),
+            desc="bulk_get_push_rules",
+        )
+
+        rows.sort(key=lambda e: (-e["priority_class"], -e["priority"]))
 
-            for row in rows:
-                results.setdefault(row['user_name'], []).append(row)
+        for row in rows:
+            results.setdefault(row['user_name'], []).append(row)
         defer.returnValue(results)
 
     @defer.inlineCallbacks
@@ -98,28 +86,17 @@ class PushRuleStore(SQLBaseStore):
         if not user_ids:
             defer.returnValue({})
 
-        batch_size = 100
-
-        def f(txn, user_ids_to_fetch):
-            sql = (
-                "SELECT user_name, rule_id, enabled"
-                " FROM push_rules_enable"
-                " WHERE user_name"
-                " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
-            )
-            txn.execute(sql, user_ids_to_fetch)
-            return self.cursor_to_dict(txn)
-
         results = {}
 
-        chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
-        for batch_user_ids in chunks:
-            rows = yield self.runInteraction(
-                "bulk_get_push_rules_enabled", f, batch_user_ids
-            )
-
-            for row in rows:
-                results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
+        rows = yield self._simple_select_many_batch(
+            table="push_rules_enable",
+            column="user_name",
+            iterable=user_ids,
+            retcols=("user_name", "rule_id", "enabled",),
+            desc="bulk_get_push_rules_enabled",
+        )
+        for row in rows:
+            results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
         defer.returnValue(results)
 
     @defer.inlineCallbacks
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 68ac88905f..edfecced05 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -110,6 +110,7 @@ class RoomMemberStore(SQLBaseStore):
             membership=membership,
         ).addCallback(self._get_events)
 
+    @cached()
     def get_invites_for_user(self, user_id):
         """ Get all the invite events for a user
         Args:
diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/schema/delta/28/event_push_actions.sql
index bdf6ae3f24..4d519849df 100644
--- a/synapse/storage/schema/delta/28/event_push_actions.sql
+++ b/synapse/storage/schema/delta/28/event_push_actions.sql
@@ -24,3 +24,4 @@ CREATE TABLE IF NOT EXISTS event_push_actions(
 
 
 CREATE INDEX event_push_actions_room_id_event_id_user_id_profile_tag on event_push_actions(room_id, event_id, user_id, profile_tag);
+CREATE INDEX event_push_actions_room_id_user_id on event_push_actions(room_id, user_id);