diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 183a752387..90d7aee94a 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -629,6 +629,78 @@ class SQLBaseStore(object):
return self.cursor_to_dict(txn)
+ @defer.inlineCallbacks
+ def _simple_select_many_batch(self, table, column, iterable, retcols,
+ keyvalues={}, desc="_simple_select_many_batch",
+ batch_size=100):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Filters rows by if value of `column` is in `iterable`.
+
+ Args:
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ results = []
+
+ if not iterable:
+ defer.returnValue(results)
+
+ chunks = [iterable[i:i+batch_size] for i in xrange(0, len(iterable), batch_size)]
+ for chunk in chunks:
+ rows = yield self.runInteraction(
+ desc,
+ self._simple_select_many_txn,
+ table, column, chunk, keyvalues, retcols
+ )
+
+ results.extend(rows)
+
+ defer.returnValue(results)
+
+ def _simple_select_many_txn(self, txn, table, column, iterable, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Filters rows by if value of `column` is in `iterable`.
+
+ Args:
+ txn : Transaction object
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ if not iterable:
+ return []
+
+ sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+
+ clauses = []
+ values = []
+ clauses.append(
+ "%s IN (%s)" % (column, ",".join("?" for _ in iterable))
+ )
+ values.extend(iterable)
+
+ for key, value in keyvalues.items():
+ clauses.append("%s = ?" % (key,))
+ values.append(value)
+
+ if clauses:
+ sql = "%s WHERE %s" % (
+ sql,
+ " AND ".join(clauses),
+ )
+
+ txn.execute(sql, values)
+ return self.cursor_to_dict(txn)
+
def _simple_update_one(self, table, keyvalues, updatevalues,
desc="_simple_update_one"):
"""Executes an UPDATE query on the named table, setting new values for
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 1095d52ace..9b3aecaf8c 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -48,24 +48,25 @@ class PresenceStore(SQLBaseStore):
desc="get_presence_state",
)
- @cachedList(get_presence_state.cache, list_name="user_localparts")
+ @cachedList(get_presence_state.cache, list_name="user_localparts",
+ inlineCallbacks=True)
def get_presence_states(self, user_localparts):
- def f(txn):
- results = {}
- for user_localpart in user_localparts:
- res = self._simple_select_one_txn(
- txn,
- table="presence",
- keyvalues={"user_id": user_localpart},
- retcols=["state", "status_msg", "mtime"],
- allow_none=True,
- )
- if res:
- results[user_localpart] = res
-
- return results
-
- return self.runInteraction("get_presence_states", f)
+ rows = yield self._simple_select_many_batch(
+ table="presence",
+ column="user_id",
+ iterable=user_localparts,
+ retcols=("user_id", "state", "status_msg", "mtime",),
+ desc="get_presence_states",
+ )
+
+ defer.returnValue({
+ row["user_id"]: {
+ "state": row["state"],
+ "status_msg": row["status_msg"],
+ "mtime": row["mtime"],
+ }
+ for row in rows
+ })
def set_presence_state(self, user_localpart, new_state):
res = self._simple_update_one(
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 35ec7e8cef..1f51c90ee5 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -65,32 +65,20 @@ class PushRuleStore(SQLBaseStore):
if not user_ids:
defer.returnValue({})
- batch_size = 100
-
- def f(txn, user_ids_to_fetch):
- sql = (
- "SELECT pr.*"
- " FROM push_rules AS pr"
- " LEFT JOIN push_rules_enable AS pre"
- " ON pr.user_name = pre.user_name AND pr.rule_id = pre.rule_id"
- " WHERE pr.user_name"
- " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
- " AND (pre.enabled IS NULL OR pre.enabled = 1)"
- " ORDER BY pr.user_name, pr.priority_class DESC, pr.priority DESC"
- )
- txn.execute(sql, user_ids_to_fetch)
- return self.cursor_to_dict(txn)
-
results = {}
- chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
- for batch_user_ids in chunks:
- rows = yield self.runInteraction(
- "bulk_get_push_rules", f, batch_user_ids
- )
+ rows = yield self._simple_select_many_batch(
+ table="push_rules",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("*",),
+ desc="bulk_get_push_rules",
+ )
+
+ rows.sort(key=lambda e: (-e["priority_class"], -e["priority"]))
- for row in rows:
- results.setdefault(row['user_name'], []).append(row)
+ for row in rows:
+ results.setdefault(row['user_name'], []).append(row)
defer.returnValue(results)
@defer.inlineCallbacks
@@ -98,28 +86,17 @@ class PushRuleStore(SQLBaseStore):
if not user_ids:
defer.returnValue({})
- batch_size = 100
-
- def f(txn, user_ids_to_fetch):
- sql = (
- "SELECT user_name, rule_id, enabled"
- " FROM push_rules_enable"
- " WHERE user_name"
- " IN (" + ",".join("?" for _ in user_ids_to_fetch) + ")"
- )
- txn.execute(sql, user_ids_to_fetch)
- return self.cursor_to_dict(txn)
-
results = {}
- chunks = [user_ids[i:i+batch_size] for i in xrange(0, len(user_ids), batch_size)]
- for batch_user_ids in chunks:
- rows = yield self.runInteraction(
- "bulk_get_push_rules_enabled", f, batch_user_ids
- )
-
- for row in rows:
- results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
+ rows = yield self._simple_select_many_batch(
+ table="push_rules_enable",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("user_name", "rule_id", "enabled",),
+ desc="bulk_get_push_rules_enabled",
+ )
+ for row in rows:
+ results.setdefault(row['user_name'], {})[row['rule_id']] = row['enabled']
defer.returnValue(results)
@defer.inlineCallbacks
|