diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2796354a1f..4d82c4c26d 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -703,13 +703,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(max_entries=10000)
async def is_host_joined(self, room_id: str, host: str) -> bool:
+ return await self._check_host_room_membership(room_id, host, Membership.JOIN)
+
+ @cached(max_entries=10000)
+ async def is_host_invited(self, room_id: str, host: str) -> bool:
+ return await self._check_host_room_membership(room_id, host, Membership.INVITE)
+
+ async def _check_host_room_membership(
+ self, room_id: str, host: str, membership: str
+ ) -> bool:
if "%" in host or "_" in host:
raise Exception("Invalid host name")
sql = """
SELECT state_key FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (event_id)
- WHERE m.membership = 'join'
+ WHERE m.membership = ?
AND type = 'm.room.member'
AND c.room_id = ?
AND state_key LIKE ?
@@ -722,7 +731,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
like_clause = "%:" + host
rows = await self.db_pool.execute(
- "is_host_joined", None, sql, room_id, like_clause
+ "is_host_joined", None, sql, membership, room_id, like_clause
)
if not rows:
|