diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index fff6652e05..908551d6d9 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,12 +16,13 @@
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.storage._base import SQLBaseStore
from synapse.storage.search import SearchStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
import collections
import logging
-import ujson as json
+import simplejson as json
import re
logger = logging.getLogger(__name__)
@@ -38,7 +39,138 @@ RatelimitOverride = collections.namedtuple(
)
-class RoomStore(SearchStore):
+class RoomWorkerStore(SQLBaseStore):
+ def get_public_room_ids(self):
+ return self._simple_select_onecol(
+ table="rooms",
+ keyvalues={
+ "is_public": True,
+ },
+ retcol="room_id",
+ desc="get_public_room_ids",
+ )
+
+ @cached(num_args=2, max_entries=100)
+ def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
+ """Get pulbic rooms for a particular list, or across all lists.
+
+ Args:
+ stream_id (int)
+ network_tuple (ThirdPartyInstanceID): The list to use (None, None)
+ means the main list, None means all lsits.
+ """
+ return self.runInteraction(
+ "get_public_room_ids_at_stream_id",
+ self.get_public_room_ids_at_stream_id_txn,
+ stream_id, network_tuple=network_tuple
+ )
+
+ def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
+ network_tuple):
+ return {
+ rm
+ for rm, vis in self.get_published_at_stream_id_txn(
+ txn, stream_id, network_tuple=network_tuple
+ ).items()
+ if vis
+ }
+
+ def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
+ if network_tuple:
+ # We want to get from a particular list. No aggregation required.
+
+ sql = ("""
+ SELECT room_id, visibility FROM public_room_list_stream
+ INNER JOIN (
+ SELECT room_id, max(stream_id) AS stream_id
+ FROM public_room_list_stream
+ WHERE stream_id <= ? %s
+ GROUP BY room_id
+ ) grouped USING (room_id, stream_id)
+ """)
+
+ if network_tuple.appservice_id is not None:
+ txn.execute(
+ sql % ("AND appservice_id = ? AND network_id = ?",),
+ (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
+ )
+ else:
+ txn.execute(
+ sql % ("AND appservice_id IS NULL",),
+ (stream_id,)
+ )
+ return dict(txn)
+ else:
+ # We want to get from all lists, so we need to aggregate the results
+
+ logger.info("Executing full list")
+
+ sql = ("""
+ SELECT room_id, visibility
+ FROM public_room_list_stream
+ INNER JOIN (
+ SELECT
+ room_id, max(stream_id) AS stream_id, appservice_id,
+ network_id
+ FROM public_room_list_stream
+ WHERE stream_id <= ?
+ GROUP BY room_id, appservice_id, network_id
+ ) grouped USING (room_id, stream_id)
+ """)
+
+ txn.execute(
+ sql,
+ (stream_id,)
+ )
+
+ results = {}
+ # A room is visible if its visible on any list.
+ for room_id, visibility in txn:
+ results[room_id] = bool(visibility) or results.get(room_id, False)
+
+ return results
+
+ def get_public_room_changes(self, prev_stream_id, new_stream_id,
+ network_tuple):
+ def get_public_room_changes_txn(txn):
+ then_rooms = self.get_public_room_ids_at_stream_id_txn(
+ txn, prev_stream_id, network_tuple
+ )
+
+ now_rooms_dict = self.get_published_at_stream_id_txn(
+ txn, new_stream_id, network_tuple
+ )
+
+ now_rooms_visible = set(
+ rm for rm, vis in now_rooms_dict.items() if vis
+ )
+ now_rooms_not_visible = set(
+ rm for rm, vis in now_rooms_dict.items() if not vis
+ )
+
+ newly_visible = now_rooms_visible - then_rooms
+ newly_unpublished = now_rooms_not_visible & then_rooms
+
+ return newly_visible, newly_unpublished
+
+ return self.runInteraction(
+ "get_public_room_changes", get_public_room_changes_txn
+ )
+
+ @cached(max_entries=10000)
+ def is_room_blocked(self, room_id):
+ return self._simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={
+ "room_id": room_id,
+ },
+ retcol="1",
+ allow_none=True,
+ desc="is_room_blocked",
+ )
+
+
+class RoomStore(RoomWorkerStore, SearchStore):
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
@@ -225,16 +357,6 @@ class RoomStore(SearchStore):
)
self.hs.get_notifier().on_new_replication_data()
- def get_public_room_ids(self):
- return self._simple_select_onecol(
- table="rooms",
- keyvalues={
- "is_public": True,
- },
- retcol="room_id",
- desc="get_public_room_ids",
- )
-
def get_room_count(self):
"""Retrieve a list of all rooms
"""
@@ -326,113 +448,6 @@ class RoomStore(SearchStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- @cached(num_args=2, max_entries=100)
- def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
- """Get pulbic rooms for a particular list, or across all lists.
-
- Args:
- stream_id (int)
- network_tuple (ThirdPartyInstanceID): The list to use (None, None)
- means the main list, None means all lsits.
- """
- return self.runInteraction(
- "get_public_room_ids_at_stream_id",
- self.get_public_room_ids_at_stream_id_txn,
- stream_id, network_tuple=network_tuple
- )
-
- def get_public_room_ids_at_stream_id_txn(self, txn, stream_id,
- network_tuple):
- return {
- rm
- for rm, vis in self.get_published_at_stream_id_txn(
- txn, stream_id, network_tuple=network_tuple
- ).items()
- if vis
- }
-
- def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
- if network_tuple:
- # We want to get from a particular list. No aggregation required.
-
- sql = ("""
- SELECT room_id, visibility FROM public_room_list_stream
- INNER JOIN (
- SELECT room_id, max(stream_id) AS stream_id
- FROM public_room_list_stream
- WHERE stream_id <= ? %s
- GROUP BY room_id
- ) grouped USING (room_id, stream_id)
- """)
-
- if network_tuple.appservice_id is not None:
- txn.execute(
- sql % ("AND appservice_id = ? AND network_id = ?",),
- (stream_id, network_tuple.appservice_id, network_tuple.network_id,)
- )
- else:
- txn.execute(
- sql % ("AND appservice_id IS NULL",),
- (stream_id,)
- )
- return dict(txn)
- else:
- # We want to get from all lists, so we need to aggregate the results
-
- logger.info("Executing full list")
-
- sql = ("""
- SELECT room_id, visibility
- FROM public_room_list_stream
- INNER JOIN (
- SELECT
- room_id, max(stream_id) AS stream_id, appservice_id,
- network_id
- FROM public_room_list_stream
- WHERE stream_id <= ?
- GROUP BY room_id, appservice_id, network_id
- ) grouped USING (room_id, stream_id)
- """)
-
- txn.execute(
- sql,
- (stream_id,)
- )
-
- results = {}
- # A room is visible if its visible on any list.
- for room_id, visibility in txn:
- results[room_id] = bool(visibility) or results.get(room_id, False)
-
- return results
-
- def get_public_room_changes(self, prev_stream_id, new_stream_id,
- network_tuple):
- def get_public_room_changes_txn(txn):
- then_rooms = self.get_public_room_ids_at_stream_id_txn(
- txn, prev_stream_id, network_tuple
- )
-
- now_rooms_dict = self.get_published_at_stream_id_txn(
- txn, new_stream_id, network_tuple
- )
-
- now_rooms_visible = set(
- rm for rm, vis in now_rooms_dict.items() if vis
- )
- now_rooms_not_visible = set(
- rm for rm, vis in now_rooms_dict.items() if not vis
- )
-
- newly_visible = now_rooms_visible - then_rooms
- newly_unpublished = now_rooms_not_visible & then_rooms
-
- return newly_visible, newly_unpublished
-
- return self.runInteraction(
- "get_public_room_changes", get_public_room_changes_txn
- )
-
def get_all_new_public_rooms(self, prev_id, current_id, limit):
def get_all_new_public_rooms(txn):
sql = ("""
@@ -482,18 +497,6 @@ class RoomStore(SearchStore):
else:
defer.returnValue(None)
- @cached(max_entries=10000)
- def is_room_blocked(self, room_id):
- return self._simple_select_one_onecol(
- table="blocked_rooms",
- keyvalues={
- "room_id": room_id,
- },
- retcol="1",
- allow_none=True,
- desc="is_room_blocked",
- )
-
@defer.inlineCallbacks
def block_room(self, room_id, user_id):
yield self._simple_insert(
@@ -504,7 +507,11 @@ class RoomStore(SearchStore):
},
desc="block_room",
)
- self.is_room_blocked.invalidate((room_id,))
+ yield self.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked, (room_id,),
+ )
def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
|