summary refs log tree commit diff
path: root/synapse/storage/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/roommember.py')
-rw-r--r--synapse/storage/roommember.py95
1 files changed, 74 insertions, 21 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 0829ae5bee..7155bfdc69 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
 from collections import namedtuple
 
 from ._base import SQLBaseStore
+from synapse.util.async import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.stringutils import to_ascii
@@ -392,7 +393,8 @@ class RoomMemberStore(SQLBaseStore):
             context=context,
         )
 
-    def get_joined_users_from_state(self, room_id, state_group, state_ids):
+    def get_joined_users_from_state(self, room_id, state_entry):
+        state_group = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -401,7 +403,7 @@ class RoomMemberStore(SQLBaseStore):
             state_group = object()
 
         return self._get_joined_users_from_context(
-            room_id, state_group, state_ids,
+            room_id, state_group, state_entry.state, context=state_entry,
         )
 
     @cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
@@ -534,7 +536,8 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(False)
 
-    def get_joined_hosts(self, room_id, state_group, state_ids):
+    def get_joined_hosts(self, room_id, state_entry):
+        state_group = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -543,33 +546,20 @@ class RoomMemberStore(SQLBaseStore):
             state_group = object()
 
         return self._get_joined_hosts(
-            room_id, state_group, state_ids
+            room_id, state_group, state_entry.state, state_entry=state_entry,
         )
 
     @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
-    def _get_joined_hosts(self, room_id, state_group, current_state_ids):
+    # @defer.inlineCallbacks
+    def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
         # We don't use `state_group`, its there so that we can cache based
         # on it. However, its important that its never None, since two current_state's
         # with a state_group of None are likely to be different.
         # See bulk_get_push_rules_for_room for how we work around this.
         assert state_group is not None
 
-        joined_hosts = set()
-        for etype, state_key in current_state_ids:
-            if etype == EventTypes.Member:
-                try:
-                    host = get_domain_from_id(state_key)
-                except:
-                    logger.warn("state_key not user_id: %s", state_key)
-                    continue
-
-                if host in joined_hosts:
-                    continue
-
-                event_id = current_state_ids[(etype, state_key)]
-                event = yield self.get_event(event_id, allow_none=True)
-                if event and event.content["membership"] == Membership.JOIN:
-                    joined_hosts.add(intern_string(host))
+        cache = self._get_joined_hosts_cache(room_id)
+        joined_hosts = yield cache.get_destinations(state_entry)
 
         defer.returnValue(joined_hosts)
 
@@ -647,3 +637,66 @@ class RoomMemberStore(SQLBaseStore):
             yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
 
         defer.returnValue(result)
+
+    @cached(max_entries=10000, iterable=True)
+    def _get_joined_hosts_cache(self, room_id):
+        return _JoinedHostsCache(self, room_id)
+
+
+class _JoinedHostsCache(object):
+    def __init__(self, store, room_id):
+        self.store = store
+        self.room_id = room_id
+
+        self.hosts_to_joined_users = {}
+
+        self.state_group = object()
+
+        self.linearizer = Linearizer("_JoinedHostsCache")
+
+        self._len = 0
+
+    @defer.inlineCallbacks
+    def get_destinations(self, state_entry):
+        if state_entry.state_group == self.state_group:
+            defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+        with (yield self.linearizer.queue(())):
+            if state_entry.state_group == self.state_group:
+                pass
+            elif state_entry.prev_group == self.state_group:
+                for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
+                    if typ != EventTypes.Member:
+                        continue
+
+                    host = intern_string(get_domain_from_id(state_key))
+                    user_id = state_key
+                    known_joins = self.hosts_to_joined_users.setdefault(host, set())
+
+                    event = yield self.store.get_event(event_id)
+                    if event.membership == Membership.JOIN:
+                        known_joins.add(user_id)
+                    else:
+                        known_joins.discard(user_id)
+
+                        if not known_joins:
+                            self.hosts_to_joined_users.pop(host, None)
+            else:
+                joined_users = yield self.store.get_joined_users_from_state(
+                    self.room_id, state_entry,
+                )
+
+                self.hosts_to_joined_users = {}
+                for user_id in joined_users:
+                    host = intern_string(get_domain_from_id(user_id))
+                    self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+            if state_entry.state_group:
+                self.state_group = state_entry.state_group
+            else:
+                self.state_group = object()
+            self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
+        defer.returnValue(frozenset(self.hosts_to_joined_users))
+
+    def __len__(self):
+        return self._len