diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 7f2c08d7a6..34ed84ea22 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -157,6 +157,18 @@ class RoomWorkerStore(SQLBaseStore):
"get_public_room_changes", get_public_room_changes_txn
)
+ @cached(max_entries=10000)
+ def is_room_blocked(self, room_id):
+ return self._simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="1",
+ allow_none=True,
+ desc="is_room_blocked",
+ )
+
class RoomStore(RoomWorkerStore, SearchStore):
@@ -485,18 +497,6 @@ class RoomStore(RoomWorkerStore, SearchStore):
else:
defer.returnValue(None)
- @cached(max_entries=10000)
- def is_room_blocked(self, room_id):
- return self._simple_select_one_onecol(
- table="blocked_rooms",
- keyvalues={
- "room_id": room_id,
- },
- retcol="1",
- allow_none=True,
- desc="is_room_blocked",
- )
-
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
yield self._simple_insert(
@@ -507,7 +507,11 @@ class RoomStore(RoomWorkerStore, SearchStore):
},
desc="block_room",
)
- self.is_room_blocked.invalidate((room_id,))
+ yield self.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked, (room_id,),
+ )
def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 2b325e1c1f..ffa4246031 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -240,6 +240,9 @@ class StateGroupWorkerStore(SQLBaseStore):
(
"AND type = ? AND state_key = ?",
(etype, state_key)
+ ) if state_key is not None else (
+ "AND type = ?",
+ (etype,)
)
for etype, state_key in types
]
@@ -259,10 +262,19 @@ class StateGroupWorkerStore(SQLBaseStore):
key = (typ, state_key)
results[group][key] = event_id
else:
+ where_args = []
+ where_clauses = []
+ wildcard_types = False
if types is not None:
- where_clause = "AND (%s)" % (
- " OR ".join(["(type = ? AND state_key = ?)"] * len(types)),
- )
+ for typ in types:
+ if typ[1] is None:
+ where_clauses.append("(type = ?)")
+ where_args.extend(typ[0])
+ wildcard_types = True
+ else:
+ where_clauses.append("(type = ? AND state_key = ?)")
+ where_args.extend([typ[0], typ[1]])
+ where_clause = "AND (%s)" % (" OR ".join(where_clauses))
else:
where_clause = ""
@@ -279,7 +291,7 @@ class StateGroupWorkerStore(SQLBaseStore):
# after we finish deduping state, which requires this func)
args = [next_group]
if types:
- args.extend(i for typ in types for i in typ)
+ args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
@@ -292,9 +304,17 @@ class StateGroupWorkerStore(SQLBaseStore):
if (typ, state_key) not in results[group]
)
- # If the lengths match then we must have all the types,
- # so no need to go walk further down the tree.
- if types is not None and len(results[group]) == len(types):
+ # If the number of entries in the (type,state_key)->event_id dict
+ # matches the number of (type,state_keys) types we were searching
+ # for, then we must have found them all, so no need to go walk
+ # further down the tree... UNLESS our types filter contained
+ # wildcards (i.e. Nones) in which case we have to do an exhaustive
+ # search
+ if (
+ types is not None and
+ not wildcard_types and
+ len(results[group]) == len(types)
+ ):
break
next_group = self._simple_select_one_onecol_txn(
|