summary refs log tree commit diff
path: root/synapse/storage/roommember.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2016-12-16 10:40:10 +0000
committerErik Johnston <erik@matrix.org>2016-12-16 10:40:10 +0000
commitf5a4001bb116c468cc5e8e0ae04a1c570e2cb171 (patch)
treefce7147d9b4422f76b5cec8b53312bb34932d84f /synapse/storage/roommember.py
parentMerge pull request #1685 from matrix-org/rav/update_readme_for_tests (diff)
parentBump version and changelog (diff)
downloadsynapse-f5a4001bb116c468cc5e8e0ae04a1c570e2cb171.tar.xz
Merge branch 'release-v0.18.5' of github.com:matrix-org/synapse v0.18.5
Diffstat (limited to 'synapse/storage/roommember.py')
-rw-r--r--synapse/storage/roommember.py102
1 files changed, 99 insertions, 3 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 866d64e679..946d5a81cc 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -24,6 +24,7 @@ from synapse.api.constants import Membership, EventTypes
 from synapse.types import get_domain_from_id
 
 import logging
+import ujson as json
 
 logger = logging.getLogger(__name__)
 
@@ -34,7 +35,15 @@ RoomsForUser = namedtuple(
 )
 
 
+_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
+
+
 class RoomMemberStore(SQLBaseStore):
+    def __init__(self, hs):
+        super(RoomMemberStore, self).__init__(hs)
+        self.register_background_update_handler(
+            _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+        )
 
     def _store_room_members_txn(self, txn, events, backfilled):
         """Store a room member in the database.
@@ -49,6 +58,8 @@ class RoomMemberStore(SQLBaseStore):
                     "sender": event.user_id,
                     "room_id": event.room_id,
                     "membership": event.membership,
+                    "display_name": event.content.get("displayname", None),
+                    "avatar_url": event.content.get("avatar_url", None),
                 }
                 for event in events
             ]
@@ -398,7 +409,7 @@ class RoomMemberStore(SQLBaseStore):
             table="room_memberships",
             column="event_id",
             iterable=member_event_ids,
-            retcols=['user_id'],
+            retcols=['user_id', 'display_name', 'avatar_url'],
             keyvalues={
                 "membership": Membership.JOIN,
             },
@@ -406,11 +417,21 @@ class RoomMemberStore(SQLBaseStore):
             desc="_get_joined_users_from_context",
         )
 
-        users_in_room = set(row["user_id"] for row in rows)
+        users_in_room = {
+            row["user_id"]: {
+                "display_name": row["display_name"],
+                "avatar_url": row["avatar_url"],
+            }
+            for row in rows
+        }
+
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 if event.event_id in member_event_ids:
-                    users_in_room.add(event.state_key)
+                    users_in_room[event.state_key] = {
+                        "display_name": event.content.get("displayname", None),
+                        "avatar_url": event.content.get("avatar_url", None),
+                    }
 
         defer.returnValue(users_in_room)
 
@@ -448,3 +469,78 @@ class RoomMemberStore(SQLBaseStore):
                     defer.returnValue(True)
 
         defer.returnValue(False)
+
+    @defer.inlineCallbacks
+    def _background_add_membership_profile(self, progress, batch_size):
+        target_min_stream_id = progress.get(
+            "target_min_stream_id_inclusive", self._min_stream_order_on_start
+        )
+        max_stream_id = progress.get(
+            "max_stream_id_exclusive", self._stream_order_on_start + 1
+        )
+
+        INSERT_CLUMP_SIZE = 1000
+
+        def add_membership_profile_txn(txn):
+            sql = ("""
+                SELECT stream_ordering, event_id, events.room_id, content
+                FROM events
+                INNER JOIN room_memberships USING (event_id)
+                WHERE ? <= stream_ordering AND stream_ordering < ?
+                AND type = 'm.room.member'
+                ORDER BY stream_ordering DESC
+                LIMIT ?
+            """)
+
+            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+            rows = self.cursor_to_dict(txn)
+            if not rows:
+                return 0
+
+            min_stream_id = rows[-1]["stream_ordering"]
+
+            to_update = []
+            for row in rows:
+                event_id = row["event_id"]
+                room_id = row["room_id"]
+                try:
+                    content = json.loads(row["content"])
+                except:
+                    continue
+
+                display_name = content.get("displayname", None)
+                avatar_url = content.get("avatar_url", None)
+
+                if display_name or avatar_url:
+                    to_update.append((
+                        display_name, avatar_url, event_id, room_id
+                    ))
+
+            to_update_sql = ("""
+                UPDATE room_memberships SET display_name = ?, avatar_url = ?
+                WHERE event_id = ? AND room_id = ?
+            """)
+            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
+                clump = to_update[index:index + INSERT_CLUMP_SIZE]
+                txn.executemany(to_update_sql, clump)
+
+            progress = {
+                "target_min_stream_id_inclusive": target_min_stream_id,
+                "max_stream_id_exclusive": min_stream_id,
+            }
+
+            self._background_update_progress_txn(
+                txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
+            )
+
+            return len(rows)
+
+        result = yield self.runInteraction(
+            _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
+        )
+
+        if not result:
+            yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
+
+        defer.returnValue(result)