summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/roommember.py361
1 files changed, 182 insertions, 179 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 3e77fd3901..b9158b9896 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,7 +18,7 @@ from twisted.internet import defer
 
 from collections import namedtuple
 
-from ._base import SQLBaseStore
+from synapse.storage.events import EventsWorkerStore
 from synapse.util.async import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -48,97 +49,7 @@ ProfileInfo = namedtuple(
 _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 
 
-class RoomMemberStore(SQLBaseStore):
-    def __init__(self, db_conn, hs):
-        super(RoomMemberStore, self).__init__(db_conn, hs)
-        self.register_background_update_handler(
-            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
-        )
-
-    def _store_room_members_txn(self, txn, events, backfilled):
-        """Store a room member in the database.
-        """
-        self._simple_insert_many_txn(
-            txn,
-            table="room_memberships",
-            values=[
-                {
-                    "event_id": event.event_id,
-                    "user_id": event.state_key,
-                    "sender": event.user_id,
-                    "room_id": event.room_id,
-                    "membership": event.membership,
-                    "display_name": event.content.get("displayname", None),
-                    "avatar_url": event.content.get("avatar_url", None),
-                }
-                for event in events
-            ]
-        )
-
-        for event in events:
-            txn.call_after(
-                self._membership_stream_cache.entity_has_changed,
-                event.state_key, event.internal_metadata.stream_ordering
-            )
-            txn.call_after(
-                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
-            )
-
-            # We update the local_invites table only if the event is "current",
-            # i.e., its something that has just happened.
-            # The only current event that can also be an outlier is if its an
-            # invite that has come in across federation.
-            is_new_state = not backfilled and (
-                not event.internal_metadata.is_outlier()
-                or event.internal_metadata.is_invite_from_remote()
-            )
-            is_mine = self.hs.is_mine_id(event.state_key)
-            if is_new_state and is_mine:
-                if event.membership == Membership.INVITE:
-                    self._simple_insert_txn(
-                        txn,
-                        table="local_invites",
-                        values={
-                            "event_id": event.event_id,
-                            "invitee": event.state_key,
-                            "inviter": event.sender,
-                            "room_id": event.room_id,
-                            "stream_id": event.internal_metadata.stream_ordering,
-                        }
-                    )
-                else:
-                    sql = (
-                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
-                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-                        " AND replaced_by is NULL"
-                    )
-
-                    txn.execute(sql, (
-                        event.internal_metadata.stream_ordering,
-                        event.event_id,
-                        event.room_id,
-                        event.state_key,
-                    ))
-
-    @defer.inlineCallbacks
-    def locally_reject_invite(self, user_id, room_id):
-        sql = (
-            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
-            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
-            " AND replaced_by is NULL"
-        )
-
-        def f(txn, stream_ordering):
-            txn.execute(sql, (
-                stream_ordering,
-                True,
-                room_id,
-                user_id,
-            ))
-
-        with self._stream_id_gen.get_next() as stream_ordering:
-            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
+class RoomMemberWorkerStore(EventsWorkerStore):
     @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
     def get_hosts_in_room(self, room_id, cache_context):
         """Returns the set of all hosts currently in the room
@@ -295,89 +206,6 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(user_who_share_room)
 
-    def forget(self, user_id, room_id):
-        """Indicate that user_id wishes to discard history for room_id."""
-        def f(txn):
-            sql = (
-                "UPDATE"
-                "  room_memberships"
-                " SET"
-                "  forgotten = 1"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-            )
-            txn.execute(sql, (user_id, room_id))
-
-            txn.call_after(self.was_forgotten_at.invalidate_all)
-            txn.call_after(self.did_forget.invalidate, (user_id, room_id))
-            self._invalidate_cache_and_stream(
-                txn, self.who_forgot_in_room, (room_id,)
-            )
-        return self.runInteraction("forget_membership", f)
-
-    @cachedInlineCallbacks(num_args=2)
-    def did_forget(self, user_id, room_id):
-        """Returns whether user_id has elected to discard history for room_id.
-
-        Returns False if they have since re-joined."""
-        def f(txn):
-            sql = (
-                "SELECT"
-                "  COUNT(*)"
-                " FROM"
-                "  room_memberships"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-                " AND"
-                "  forgotten = 0"
-            )
-            txn.execute(sql, (user_id, room_id))
-            rows = txn.fetchall()
-            return rows[0][0]
-        count = yield self.runInteraction("did_forget_membership", f)
-        defer.returnValue(count == 0)
-
-    @cachedInlineCallbacks(num_args=3)
-    def was_forgotten_at(self, user_id, room_id, event_id):
-        """Returns whether user_id has elected to discard history for room_id at
-        event_id.
-
-        event_id must be a membership event."""
-        def f(txn):
-            sql = (
-                "SELECT"
-                "  forgotten"
-                " FROM"
-                "  room_memberships"
-                " WHERE"
-                "  user_id = ?"
-                " AND"
-                "  room_id = ?"
-                " AND"
-                "  event_id = ?"
-            )
-            txn.execute(sql, (user_id, room_id, event_id))
-            rows = txn.fetchall()
-            return rows[0][0]
-        forgot = yield self.runInteraction("did_forget_membership_at", f)
-        defer.returnValue(forgot == 1)
-
-    @cached()
-    def who_forgot_in_room(self, room_id):
-        return self._simple_select_list(
-            table="room_memberships",
-            retcols=("user_id", "event_id"),
-            keyvalues={
-                "room_id": room_id,
-                "forgotten": 1,
-            },
-            desc="who_forgot"
-        )
-
     def get_joined_users_from_context(self, event, context):
         state_group = context.state_group
         if not state_group:
@@ -600,6 +428,185 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(joined_hosts)
 
+    @cached(max_entries=10000, iterable=True)
+    def _get_joined_hosts_cache(self, room_id):
+        return _JoinedHostsCache(self, room_id)
+
+
+class RoomMemberStore(RoomMemberWorkerStore):
+    def __init__(self, db_conn, hs):
+        super(RoomMemberStore, self).__init__(db_conn, hs)
+        self.register_background_update_handler(
+            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+        )
+
+    def _store_room_members_txn(self, txn, events, backfilled):
+        """Store a room member in the database.
+        """
+        self._simple_insert_many_txn(
+            txn,
+            table="room_memberships",
+            values=[
+                {
+                    "event_id": event.event_id,
+                    "user_id": event.state_key,
+                    "sender": event.user_id,
+                    "room_id": event.room_id,
+                    "membership": event.membership,
+                    "display_name": event.content.get("displayname", None),
+                    "avatar_url": event.content.get("avatar_url", None),
+                }
+                for event in events
+            ]
+        )
+
+        for event in events:
+            txn.call_after(
+                self._membership_stream_cache.entity_has_changed,
+                event.state_key, event.internal_metadata.stream_ordering
+            )
+            txn.call_after(
+                self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+            )
+
+            # We update the local_invites table only if the event is "current",
+            # i.e., its something that has just happened.
+            # The only current event that can also be an outlier is if its an
+            # invite that has come in across federation.
+            is_new_state = not backfilled and (
+                not event.internal_metadata.is_outlier()
+                or event.internal_metadata.is_invite_from_remote()
+            )
+            is_mine = self.hs.is_mine_id(event.state_key)
+            if is_new_state and is_mine:
+                if event.membership == Membership.INVITE:
+                    self._simple_insert_txn(
+                        txn,
+                        table="local_invites",
+                        values={
+                            "event_id": event.event_id,
+                            "invitee": event.state_key,
+                            "inviter": event.sender,
+                            "room_id": event.room_id,
+                            "stream_id": event.internal_metadata.stream_ordering,
+                        }
+                    )
+                else:
+                    sql = (
+                        "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+                        " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+                        " AND replaced_by is NULL"
+                    )
+
+                    txn.execute(sql, (
+                        event.internal_metadata.stream_ordering,
+                        event.event_id,
+                        event.room_id,
+                        event.state_key,
+                    ))
+
+    @defer.inlineCallbacks
+    def locally_reject_invite(self, user_id, room_id):
+        sql = (
+            "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+            " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+            " AND replaced_by is NULL"
+        )
+
+        def f(txn, stream_ordering):
+            txn.execute(sql, (
+                stream_ordering,
+                True,
+                room_id,
+                user_id,
+            ))
+
+        with self._stream_id_gen.get_next() as stream_ordering:
+            yield self.runInteraction("locally_reject_invite", f, stream_ordering)
+
+    def forget(self, user_id, room_id):
+        """Indicate that user_id wishes to discard history for room_id."""
+        def f(txn):
+            sql = (
+                "UPDATE"
+                "  room_memberships"
+                " SET"
+                "  forgotten = 1"
+                " WHERE"
+                "  user_id = ?"
+                " AND"
+                "  room_id = ?"
+            )
+            txn.execute(sql, (user_id, room_id))
+
+            txn.call_after(self.was_forgotten_at.invalidate_all)
+            txn.call_after(self.did_forget.invalidate, (user_id, room_id))
+            self._invalidate_cache_and_stream(
+                txn, self.who_forgot_in_room, (room_id,)
+            )
+        return self.runInteraction("forget_membership", f)
+
+    @cachedInlineCallbacks(num_args=2)
+    def did_forget(self, user_id, room_id):
+        """Returns whether user_id has elected to discard history for room_id.
+
+        Returns False if they have since re-joined."""
+        def f(txn):
+            sql = (
+                "SELECT"
+                "  COUNT(*)"
+                " FROM"
+                "  room_memberships"
+                " WHERE"
+                "  user_id = ?"
+                " AND"
+                "  room_id = ?"
+                " AND"
+                "  forgotten = 0"
+            )
+            txn.execute(sql, (user_id, room_id))
+            rows = txn.fetchall()
+            return rows[0][0]
+        count = yield self.runInteraction("did_forget_membership", f)
+        defer.returnValue(count == 0)
+
+    @cachedInlineCallbacks(num_args=3)
+    def was_forgotten_at(self, user_id, room_id, event_id):
+        """Returns whether user_id has elected to discard history for room_id at
+        event_id.
+
+        event_id must be a membership event."""
+        def f(txn):
+            sql = (
+                "SELECT"
+                "  forgotten"
+                " FROM"
+                "  room_memberships"
+                " WHERE"
+                "  user_id = ?"
+                " AND"
+                "  room_id = ?"
+                " AND"
+                "  event_id = ?"
+            )
+            txn.execute(sql, (user_id, room_id, event_id))
+            rows = txn.fetchall()
+            return rows[0][0]
+        forgot = yield self.runInteraction("did_forget_membership_at", f)
+        defer.returnValue(forgot == 1)
+
+    @cached()
+    def who_forgot_in_room(self, room_id):
+        return self._simple_select_list(
+            table="room_memberships",
+            retcols=("user_id", "event_id"),
+            keyvalues={
+                "room_id": room_id,
+                "forgotten": 1,
+            },
+            desc="who_forgot"
+        )
+
     @defer.inlineCallbacks
     def _background_add_membership_profile(self, progress, batch_size):
         target_min_stream_id = progress.get(
@@ -675,10 +682,6 @@ class RoomMemberStore(SQLBaseStore):
 
         defer.returnValue(result)
 
-    @cached(max_entries=10000, iterable=True)
-    def _get_joined_hosts_cache(self, room_id):
-        return _JoinedHostsCache(self, room_id)
-
 
 class _JoinedHostsCache(object):
     """Cache for joined hosts in a room that is optimised to handle updates