summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/room.py4
-rw-r--r--synapse/push/__init__.py27
-rw-r--r--synapse/storage/_base.py6
-rw-r--r--synapse/storage/events.py25
-rw-r--r--synapse/storage/push_rule.py21
-rw-r--r--synapse/storage/room.py3
-rw-r--r--synapse/storage/roommember.py2
-rw-r--r--synapse/storage/state.py25
-rw-r--r--synapse/util/lrucache.py8
9 files changed, 86 insertions, 35 deletions
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index cfa2e38ed2..3da08c147e 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -531,9 +531,7 @@ class RoomListHandler(BaseHandler):
         chunk = yield self.store.get_rooms(is_public=True)
         results = yield defer.gatherResults(
             [
-                self.store.get_users_in_room(
-                    room_id=room["room_id"],
-                )
+                self.store.get_users_in_room(room["room_id"])
                 for room in chunk
             ],
             consumeErrors=True,
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 5575c847f9..e3dd4ce76d 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -84,25 +84,20 @@ class Pusher(object):
 
         rules = baserules.list_with_base_rules(rawrules, user)
 
+        room_id = ev['room_id']
+
         # get *our* member event for display name matching
-        member_events_for_room = yield self.store.get_current_state(
-            room_id=ev['room_id'],
+        my_display_name = None
+        our_member_event = yield self.store.get_current_state(
+            room_id=room_id,
             event_type='m.room.member',
-            state_key=None
+            state_key=self.user_name,
         )
-        my_display_name = None
-        room_member_count = 0
-        for mev in member_events_for_room:
-            if mev.content['membership'] != 'join':
-                continue
-
-            # This loop does two things:
-            # 1) Find our current display name
-            if mev.state_key == self.user_name and 'displayname' in mev.content:
-                my_display_name = mev.content['displayname']
+        if our_member_event:
+            my_display_name = our_member_event[0].content.get("displayname")
 
-            # and 2) Get the number of people in that room
-            room_member_count += 1
+        room_members = yield self.store.get_users_in_room(room_id)
+        room_member_count = len(room_members)
 
         for r in rules:
             if r['rule_id'] in enabled_map:
@@ -287,9 +282,11 @@ class Pusher(object):
             if len(actions) == 0:
                 logger.warn("Empty actions! Using default action.")
                 actions = Pusher.DEFAULT_ACTIONS
+
             if 'notify' not in actions and 'dont_notify' not in actions:
                 logger.warn("Neither notify nor dont_notify in actions: adding default")
                 actions.extend(Pusher.DEFAULT_ACTIONS)
+
             if 'dont_notify' in actions:
                 logger.debug(
                     "%s for %s: dont_notify",
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 9e348590ba..c8c76e58fe 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -124,6 +124,11 @@ class Cache(object):
         self.sequence += 1
         self.cache.pop(keyargs, None)
 
+    def invalidate_all(self):
+        self.check_thread()
+        self.sequence += 1
+        self.cache.clear()
+
 
 def cached(max_entries=1000, num_args=1, lru=False):
     """ A method decorator that applies a memoizing cache around the function.
@@ -175,6 +180,7 @@ def cached(max_entries=1000, num_args=1, lru=False):
                 defer.returnValue(ret)
 
         wrapped.invalidate = cache.invalidate
+        wrapped.invalidate_all = cache.invalidate_all
         wrapped.prefill = cache.prefill
         return wrapped
 
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 38395c66ab..1304219e86 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -107,6 +107,12 @@ class EventsStore(SQLBaseStore):
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
+            txn.call_after(self.get_current_state_for_key.invalidate_all)
+            txn.call_after(self.get_rooms_for_user.invalidate_all)
+            txn.call_after(self.get_users_in_room.invalidate, event.room_id)
+            txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
+            txn.call_after(self.get_room_name_and_aliases, event.room_id)
+
             self._simple_delete_txn(
                 txn,
                 table="current_state_events",
@@ -114,13 +120,6 @@ class EventsStore(SQLBaseStore):
             )
 
             for s in current_state:
-                if s.type == EventTypes.Member:
-                    txn.call_after(
-                        self.get_rooms_for_user.invalidate, s.state_key
-                    )
-                    txn.call_after(
-                        self.get_joined_hosts_for_room.invalidate, s.room_id
-                    )
                 self._simple_insert_txn(
                     txn,
                     "current_state_events",
@@ -335,6 +334,18 @@ class EventsStore(SQLBaseStore):
             )
 
             if is_new_state and not context.rejected:
+                txn.call_after(
+                    self.get_current_state_for_key.invalidate,
+                    event.room_id, event.type, event.state_key
+                )
+
+                if (event.type == EventTypes.Name
+                        or event.type == EventTypes.Aliases):
+                    txn.call_after(
+                        self.get_room_name_and_aliases.invalidate,
+                        event.room_id
+                    )
+
                 self._simple_upsert_txn(
                     txn,
                     "current_state_events",
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 34805e276e..88ee21b089 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -13,9 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import collections
-
-from ._base import SQLBaseStore, Table
+from ._base import SQLBaseStore, cached
 from twisted.internet import defer
 
 import logging
@@ -41,6 +39,7 @@ class PushRuleStore(SQLBaseStore):
 
         defer.returnValue(rows)
 
+    @cached()
     @defer.inlineCallbacks
     def get_push_rules_enabled_for_user(self, user_name):
         results = yield self._simple_select_list(
@@ -151,6 +150,10 @@ class PushRuleStore(SQLBaseStore):
 
             txn.execute(sql, (user_name, priority_class, new_rule_priority))
 
+        txn.call_after(
+            self.get_push_rules_enabled_for_user.invalidate, user_name
+        )
+
         self._simple_insert_txn(
             txn,
             table=PushRuleTable.table_name,
@@ -179,6 +182,10 @@ class PushRuleStore(SQLBaseStore):
         new_rule['priority_class'] = priority_class
         new_rule['priority'] = new_prio
 
+        txn.call_after(
+            self.get_push_rules_enabled_for_user.invalidate, user_name
+        )
+
         self._simple_insert_txn(
             txn,
             table=PushRuleTable.table_name,
@@ -201,6 +208,7 @@ class PushRuleStore(SQLBaseStore):
             {'user_name': user_name, 'rule_id': rule_id},
             desc="delete_push_rule",
         )
+        self.get_push_rules_enabled_for_user.invalidate(user_name)
 
     @defer.inlineCallbacks
     def set_push_rule_enabled(self, user_name, rule_id, enabled):
@@ -210,6 +218,7 @@ class PushRuleStore(SQLBaseStore):
             {'enabled': 1 if enabled else 0},
             desc="set_push_rule_enabled",
         )
+        self.get_push_rules_enabled_for_user.invalidate(user_name)
 
 
 class RuleNotFoundException(Exception):
@@ -220,7 +229,7 @@ class InconsistentRuleException(Exception):
     pass
 
 
-class PushRuleTable(Table):
+class PushRuleTable(object):
     table_name = "push_rules"
 
     fields = [
@@ -233,10 +242,8 @@ class PushRuleTable(Table):
         "actions",
     ]
 
-    EntryType = collections.namedtuple("PushRuleEntry", fields)
-
 
-class PushRuleEnableTable(Table):
+class PushRuleEnableTable(object):
     table_name = "push_rules_enable"
 
     fields = [
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index f956377632..4612a8aa83 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -17,7 +17,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
 
 import collections
 import logging
@@ -186,6 +186,7 @@ class RoomStore(SQLBaseStore):
                 }
             )
 
+    @cached()
     @defer.inlineCallbacks
     def get_room_name_and_aliases(self, room_id):
         def f(txn):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 839c74f63a..3691eade05 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -66,6 +66,7 @@ class RoomMemberStore(SQLBaseStore):
 
         txn.call_after(self.get_rooms_for_user.invalidate, target_user_id)
         txn.call_after(self.get_joined_hosts_for_room.invalidate, event.room_id)
+        txn.call_after(self.get_users_in_room.invalidate, event.room_id)
 
     def get_room_member(self, user_id, room_id):
         """Retrieve the current state of a room member.
@@ -87,6 +88,7 @@ class RoomMemberStore(SQLBaseStore):
 
         return self.runInteraction("get_room_member", f)
 
+    @cached()
     def get_users_in_room(self, room_id):
         def f(txn):
 
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index dbc0e49c1f..6df7350552 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
 
 from twisted.internet import defer
 
@@ -130,6 +130,12 @@ class StateStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     def get_current_state(self, room_id, event_type=None, state_key=""):
+        if event_type and state_key is not None:
+            result = yield self.get_current_state_for_key(
+                room_id, event_type, state_key
+            )
+            defer.returnValue(result)
+
         def f(txn):
             sql = (
                 "SELECT event_id FROM current_state_events"
@@ -153,6 +159,23 @@ class StateStore(SQLBaseStore):
         events = yield self.runInteraction("get_current_state", f)
         defer.returnValue(events)
 
+    @cached(num_args=3)
+    @defer.inlineCallbacks
+    def get_current_state_for_key(self, room_id, event_type, state_key):
+        def f(txn):
+            sql = (
+                "SELECT event_id FROM current_state_events"
+                " WHERE room_id = ? AND type = ? AND state_key = ?"
+            )
+
+            args = (room_id, event_type, state_key)
+            txn.execute(sql, args)
+            results = txn.fetchall()
+            return [r[0] for r in results]
+        event_ids = yield self.runInteraction("get_current_state_for_key", f)
+        events = yield self._get_events(event_ids, get_prev_content=False)
+        defer.returnValue(events)
+
 
 def _make_group_id(clock):
     return str(int(clock.time_msec())) + random_string(5)
diff --git a/synapse/util/lrucache.py b/synapse/util/lrucache.py
index 96163c90f1..cacd7e45fa 100644
--- a/synapse/util/lrucache.py
+++ b/synapse/util/lrucache.py
@@ -20,7 +20,6 @@ import threading
 
 class LruCache(object):
     """Least-recently-used cache."""
-    # TODO(mjark) Add mutex for linked list for thread safety.
     def __init__(self, max_size):
         cache = {}
         list_root = []
@@ -106,6 +105,12 @@ class LruCache(object):
                 return default
 
         @synchronized
+        def cache_clear():
+            list_root[NEXT] = list_root
+            list_root[PREV] = list_root
+            cache.clear()
+
+        @synchronized
         def cache_len():
             return len(cache)
 
@@ -120,6 +125,7 @@ class LruCache(object):
         self.pop = cache_pop
         self.len = cache_len
         self.contains = cache_contains
+        self.clear = cache_clear
 
     def __getitem__(self, key):
         result = self.get(key, self.sentinel)