summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-08-26 14:54:30 +0100
committerErik Johnston <erik@matrix.org>2016-08-26 14:54:30 +0100
commitbed10f9880068306be3fcdd15a51b1712c6159f2 (patch)
tree5a7fee7675320aff0983da1e5d787de88026fd9a
parentMerge pull request #1048 from matrix-org/erikj/fix_mail_name (diff)
downloadsynapse-bed10f9880068306be3fcdd15a51b1712c6159f2.tar.xz
Use state handler instead of get_users_in_room/get_joined_hosts
Diffstat (limited to '')
-rw-r--r--synapse/federation/federation_client.py5
-rw-r--r--synapse/handlers/directory.py8
-rw-r--r--synapse/handlers/events.py3
-rw-r--r--synapse/handlers/presence.py13
-rw-r--r--synapse/handlers/receipts.py5
-rw-r--r--synapse/handlers/sync.py3
-rw-r--r--synapse/handlers/typing.py9
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py2
-rw-r--r--synapse/replication/slave/storage/events.py2
-rw-r--r--synapse/state.py9
-rw-r--r--synapse/storage/events.py1
-rw-r--r--synapse/storage/roommember.py11
12 files changed, 44 insertions, 27 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index f2b3aceb49..67ad3dfd37 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -29,6 +29,7 @@ from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.events import FrozenEvent
+from synapse.types import get_domain_from_id
 import synapse.metrics
 
 from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@@ -63,6 +64,7 @@ class FederationClient(FederationBase):
         self._clock.looping_call(
             self._clear_tried_cache, 60 * 1000,
         )
+        self.state = hs.get_state_handler()
 
     def _clear_tried_cache(self):
         """Clear pdu_destination_tried cache"""
@@ -811,7 +813,8 @@ class FederationClient(FederationBase):
         if len(signed_events) >= limit:
             defer.returnValue(signed_events)
 
-        servers = yield self.store.get_joined_hosts_for_room(room_id)
+        users = yield self.state.get_current_user_in_room(room_id)
+        servers = set(get_domain_from_id(u) for u in users)
 
         servers = set(servers)
         servers.discard(self.server_name)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 4bea7f2b19..14352985e2 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -19,7 +19,7 @@ from ._base import BaseHandler
 
 from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
 from synapse.api.constants import EventTypes
-from synapse.types import RoomAlias, UserID
+from synapse.types import RoomAlias, UserID, get_domain_from_id
 
 import logging
 import string
@@ -55,7 +55,8 @@ class DirectoryHandler(BaseHandler):
         # TODO(erikj): Add transactions.
         # TODO(erikj): Check if there is a current association.
         if not servers:
-            servers = yield self.store.get_joined_hosts_for_room(room_id)
+            users = yield self.state.get_current_user_in_room(room_id)
+            servers = set(get_domain_from_id(u) for u in users)
 
         if not servers:
             raise SynapseError(400, "Failed to get server list")
@@ -193,7 +194,8 @@ class DirectoryHandler(BaseHandler):
                 Codes.NOT_FOUND
             )
 
-        extra_servers = yield self.store.get_joined_hosts_for_room(room_id)
+        users = yield self.state.get_current_user_in_room(room_id)
+        extra_servers = set(get_domain_from_id(u) for u in users)
         servers = set(extra_servers) | set(servers)
 
         # If this server is in the list of servers, return it first.
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 3a3a1257d3..d3685fb12a 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -47,6 +47,7 @@ class EventStreamHandler(BaseHandler):
         self.clock = hs.get_clock()
 
         self.notifier = hs.get_notifier()
+        self.state = hs.get_state_handler()
 
     @defer.inlineCallbacks
     @log_function
@@ -90,7 +91,7 @@ class EventStreamHandler(BaseHandler):
                     # Send down presence.
                     if event.state_key == auth_user_id:
                         # Send down presence for everyone in the room.
-                        users = yield self.store.get_users_in_room(event.room_id)
+                        users = yield self.state.get_current_user_in_room(event.room_id)
                         states = yield presence_handler.get_states(
                             users,
                             as_event=True,
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6a1fe76c88..73752b2f89 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -88,6 +88,8 @@ class PresenceHandler(object):
         self.notifier = hs.get_notifier()
         self.federation = hs.get_replication_layer()
 
+        self.state = hs.get_state_handler()
+
         self.federation.register_edu_handler(
             "m.presence", self.incoming_presence
         )
@@ -532,7 +534,9 @@ class PresenceHandler(object):
                 if not local_states:
                     continue
 
-                hosts = yield self.store.get_joined_hosts_for_room(room_id)
+                users = yield self.state.get_current_user_in_room(room_id)
+                hosts = set(get_domain_from_id(u) for u in users)
+
                 for host in hosts:
                     hosts_to_states.setdefault(host, []).extend(local_states)
 
@@ -725,13 +729,13 @@ class PresenceHandler(object):
         # don't need to send to local clients here, as that is done as part
         # of the event stream/sync.
         # TODO: Only send to servers not already in the room.
+        user_ids = yield self.state.get_current_user_in_room(room_id)
         if self.is_mine(user):
             state = yield self.current_state_for_user(user.to_string())
 
-            hosts = yield self.store.get_joined_hosts_for_room(room_id)
+            hosts = set(get_domain_from_id(u) for u in user_ids)
             self._push_to_remotes({host: (state,) for host in hosts})
         else:
-            user_ids = yield self.store.get_users_in_room(room_id)
             user_ids = filter(self.is_mine_id, user_ids)
 
             states = yield self.current_state_for_users(user_ids)
@@ -955,6 +959,7 @@ class PresenceEventSource(object):
         self.get_presence_handler = hs.get_presence_handler
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
 
     @defer.inlineCallbacks
     @log_function
@@ -1017,7 +1022,7 @@ class PresenceEventSource(object):
 
                 user_ids_to_check = set()
                 for room_id in room_ids:
-                    users = yield self.store.get_users_in_room(room_id)
+                    users = yield self.state.get_current_user_in_room(room_id)
                     user_ids_to_check.update(users)
 
                 user_ids_to_check.update(friends)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index e62722d78d..726f7308d2 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,6 +18,7 @@ from ._base import BaseHandler
 from twisted.internet import defer
 
 from synapse.util.logcontext import PreserveLoggingContext
+from synapse.types import get_domain_from_id
 
 import logging
 
@@ -37,6 +38,7 @@ class ReceiptsHandler(BaseHandler):
             "m.receipt", self._received_remote_receipt
         )
         self.clock = self.hs.get_clock()
+        self.state = hs.get_state_handler()
 
     @defer.inlineCallbacks
     def received_client_receipt(self, room_id, receipt_type, user_id,
@@ -133,7 +135,8 @@ class ReceiptsHandler(BaseHandler):
             event_ids = receipt["event_ids"]
             data = receipt["data"]
 
-            remotedomains = yield self.store.get_joined_hosts_for_room(room_id)
+            users = yield self.state.get_current_user_in_room(room_id)
+            remotedomains = set(get_domain_from_id(u) for u in users)
             remotedomains = remotedomains.copy()
             remotedomains.discard(self.server_name)
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 5cd009a1c8..3017bc737b 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -139,6 +139,7 @@ class SyncHandler(object):
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
         self.response_cache = ResponseCache(hs)
+        self.state = hs.get_state_handler()
 
     def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
                                full_state=False):
@@ -630,7 +631,7 @@ class SyncHandler(object):
 
         extra_users_ids = set(newly_joined_users)
         for room_id in newly_joined_rooms:
-            users = yield self.store.get_users_in_room(room_id)
+            users = yield self.state.get_current_user_in_room(room_id)
             extra_users_ids.update(users)
         extra_users_ids.discard(user.to_string())
 
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 46181984c0..0b530b9034 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -20,7 +20,7 @@ from synapse.util.logcontext import (
     PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
 )
 from synapse.util.metrics import Measure
-from synapse.types import UserID
+from synapse.types import UserID, get_domain_from_id
 
 import logging
 
@@ -42,6 +42,7 @@ class TypingHandler(object):
         self.auth = hs.get_auth()
         self.is_mine_id = hs.is_mine_id
         self.notifier = hs.get_notifier()
+        self.state = hs.get_state_handler()
 
         self.clock = hs.get_clock()
 
@@ -166,7 +167,8 @@ class TypingHandler(object):
 
     @defer.inlineCallbacks
     def _push_update(self, room_id, user_id, typing):
-        domains = yield self.store.get_joined_hosts_for_room(room_id)
+        users = yield self.state.get_current_user_in_room(room_id)
+        domains = set(get_domain_from_id(u) for u in users)
 
         deferreds = []
         for domain in domains:
@@ -199,7 +201,8 @@ class TypingHandler(object):
         # Check that the string is a valid user id
         UserID.from_string(user_id)
 
-        domains = yield self.store.get_joined_hosts_for_room(room_id)
+        users = yield self.state.get_current_user_in_room(room_id)
+        domains = set(get_domain_from_id(u) for u in users)
 
         if self.server_name in domains:
             self._push_update_local(
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 8d49beaec5..51cb21ee9d 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -87,7 +87,7 @@ class BulkPushRuleEvaluator:
         )
 
         room_members = yield self.store.get_joined_users_from_context(
-            event.room_id, context,
+            event.room_id, context.state_group, context.current_state_ids
         )
 
         evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 65e982a0ce..00ad06fa4d 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -216,7 +216,6 @@ class SlavedEventStore(BaseSlavedStore):
             self._get_current_state_for_key.invalidate_all()
             self.get_rooms_for_user.invalidate_all()
             self.get_users_in_room.invalidate((event.room_id,))
-            # self.get_joined_hosts_for_room.invalidate((event.room_id,))
 
         self._invalidate_get_event_cache(event.event_id)
 
@@ -240,7 +239,6 @@ class SlavedEventStore(BaseSlavedStore):
 
         if event.type == EventTypes.Member:
             self.get_rooms_for_user.invalidate((event.state_key,))
-            # self.get_joined_hosts_for_room.invalidate((event.room_id,))
             self.get_users_in_room.invalidate((event.room_id,))
             self._membership_stream_cache.entity_has_changed(
                 event.state_key, event.internal_metadata.stream_ordering
diff --git a/synapse/state.py b/synapse/state.py
index 78461215ca..daec983dc9 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -125,6 +125,15 @@ class StateHandler(object):
         defer.returnValue(state)
 
     @defer.inlineCallbacks
+    def get_current_user_in_room(self, room_id):
+        latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+        group, state_ids = yield self.resolve_state_groups(room_id, latest_event_ids)
+        joined_users = yield self.store.get_joined_users_from_context(
+            room_id, group, state_ids
+        )
+        defer.returnValue(joined_users)
+
+    @defer.inlineCallbacks
     def compute_event_context(self, event, old_state=None):
         """ Fills out the context with the `current state` of the graph. The
         `current state` here is defined to be the state of the event graph
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 57e5005285..5cbe8c5978 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -393,7 +393,6 @@ class EventsStore(SQLBaseStore):
             txn.call_after(self._get_current_state_for_key.invalidate_all)
             txn.call_after(self.get_rooms_for_user.invalidate_all)
             txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
 
             # Add an entry to the current_state_resets table to record the point
             # where we clobbered the current state
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 5f15200c20..cab1660830 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -56,7 +56,6 @@ class RoomMemberStore(SQLBaseStore):
 
         for event in events:
             txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
-            txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
             txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
             txn.call_after(
                 self._membership_stream_cache.entity_has_changed,
@@ -238,11 +237,6 @@ class RoomMemberStore(SQLBaseStore):
 
         return results
 
-    @cachedInlineCallbacks(max_entries=5000)
-    def get_joined_hosts_for_room(self, room_id):
-        user_ids = yield self.get_users_in_room(room_id)
-        defer.returnValue(set(get_domain_from_id(uid) for uid in user_ids))
-
     def _get_members_rows_txn(self, txn, room_id, membership=None, user_id=None):
         where_clause = "c.room_id = ?"
         where_values = [room_id]
@@ -360,8 +354,7 @@ class RoomMemberStore(SQLBaseStore):
             desc="who_forgot"
         )
 
-    def get_joined_users_from_context(self, room_id, context):
-        state_group = context.state_group
+    def get_joined_users_from_context(self, room_id, state_group, state_ids):
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
             # state group, i.e. we need to make sure that calls with a state_group
@@ -370,7 +363,7 @@ class RoomMemberStore(SQLBaseStore):
             state_group = object()
 
         return self._get_joined_users_from_context(
-            room_id, state_group, context.current_state_ids
+            room_id, state_group, state_ids
         )
 
     @cachedInlineCallbacks(num_args=2, cache_context=True)