diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1edc96042b..1f0a39eac4 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -755,81 +755,145 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
"""
@trace
- def _claim_e2e_one_time_keys(txn):
- sql = (
- "SELECT key_id, key_json FROM e2e_one_time_keys_json"
- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
- " LIMIT 1"
+ def _claim_e2e_one_time_key_simple(
+ txn, user_id: str, device_id: str, algorithm: str
+ ) -> Optional[Tuple[str, str]]:
+ """Claim OTK for device for DBs that don't support RETURNING.
+
+ Returns:
+ A tuple of key name (algorithm + key ID) and key JSON, if an
+ OTK was found.
+ """
+
+ sql = """
+ SELECT key_id, key_json FROM e2e_one_time_keys_json
+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
+ LIMIT 1
+ """
+
+ txn.execute(sql, (user_id, device_id, algorithm))
+ otk_row = txn.fetchone()
+ if otk_row is None:
+ return None
+
+ key_id, key_json = otk_row
+
+ self.db_pool.simple_delete_one_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ "key_id": key_id,
+ },
)
- fallback_sql = (
- "SELECT key_id, key_json, used FROM e2e_fallback_keys_json"
- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
- " LIMIT 1"
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
- result = {}
- delete = []
- used_fallbacks = []
- for user_id, device_id, algorithm in query_list:
- user_result = result.setdefault(user_id, {})
- device_result = user_result.setdefault(device_id, {})
- txn.execute(sql, (user_id, device_id, algorithm))
- otk_row = txn.fetchone()
- if otk_row is not None:
- key_id, key_json = otk_row
- device_result[algorithm + ":" + key_id] = key_json
- delete.append((user_id, device_id, algorithm, key_id))
- else:
- # no one-time key available, so see if there's a fallback
- # key
- txn.execute(fallback_sql, (user_id, device_id, algorithm))
- fallback_row = txn.fetchone()
- if fallback_row is not None:
- key_id, key_json, used = fallback_row
- device_result[algorithm + ":" + key_id] = key_json
- if not used:
- used_fallbacks.append(
- (user_id, device_id, algorithm, key_id)
- )
-
- # drop any one-time keys that were claimed
- sql = (
- "DELETE FROM e2e_one_time_keys_json"
- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
- " AND key_id = ?"
+
+ return f"{algorithm}:{key_id}", key_json
+
+ @trace
+ def _claim_e2e_one_time_key_returning(
+ txn, user_id: str, device_id: str, algorithm: str
+ ) -> Optional[Tuple[str, str]]:
+ """Claim OTK for device for DBs that support RETURNING.
+
+ Returns:
+ A tuple of key name (algorithm + key ID) and key JSON, if an
+ OTK was found.
+ """
+
+ # We can use RETURNING to do the fetch and DELETE in once step.
+ sql = """
+ DELETE FROM e2e_one_time_keys_json
+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
+ AND key_id IN (
+ SELECT key_id FROM e2e_one_time_keys_json
+ WHERE user_id = ? AND device_id = ? AND algorithm = ?
+ LIMIT 1
+ )
+ RETURNING key_id, key_json
+ """
+
+ txn.execute(
+ sql, (user_id, device_id, algorithm, user_id, device_id, algorithm)
)
- for user_id, device_id, algorithm, key_id in delete:
- log_kv(
- {
- "message": "Executing claim e2e_one_time_keys transaction on database."
- }
- )
- txn.execute(sql, (user_id, device_id, algorithm, key_id))
- log_kv({"message": "finished executing and invalidating cache"})
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ otk_row = txn.fetchone()
+ if otk_row is None:
+ return None
+
+ key_id, key_json = otk_row
+ return f"{algorithm}:{key_id}", key_json
+
+ results = {}
+ for user_id, device_id, algorithm in query_list:
+ if self.database_engine.supports_returning:
+ # If we support RETURNING clause we can use a single query that
+ # allows us to use autocommit mode.
+ _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning
+ db_autocommit = True
+ else:
+ _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple
+ db_autocommit = False
+
+ row = await self.db_pool.runInteraction(
+ "claim_e2e_one_time_keys",
+ _claim_e2e_one_time_key,
+ user_id,
+ device_id,
+ algorithm,
+ db_autocommit=db_autocommit,
+ )
+ if row:
+ device_results = results.setdefault(user_id, {}).setdefault(
+ device_id, {}
)
- # mark fallback keys as used
- for user_id, device_id, algorithm, key_id in used_fallbacks:
- self.db_pool.simple_update_txn(
- txn,
- "e2e_fallback_keys_json",
- {
+ device_results[row[0]] = row[1]
+ continue
+
+ # No one-time key available, so see if there's a fallback
+ # key
+ row = await self.db_pool.simple_select_one(
+ table="e2e_fallback_keys_json",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ },
+ retcols=("key_id", "key_json", "used"),
+ desc="_get_fallback_key",
+ allow_none=True,
+ )
+ if row is None:
+ continue
+
+ key_id = row["key_id"]
+ key_json = row["key_json"]
+ used = row["used"]
+
+ # Mark fallback key as used if not already.
+ if not used:
+ await self.db_pool.simple_update_one(
+ table="e2e_fallback_keys_json",
+ keyvalues={
"user_id": user_id,
"device_id": device_id,
"algorithm": algorithm,
"key_id": key_id,
},
- {"used": True},
+ updatevalues={"used": True},
+ desc="_get_fallback_key_set_used",
)
- self._invalidate_cache_and_stream(
- txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
+ await self.invalidate_cache_and_stream(
+ "get_e2e_unused_fallback_key_types", (user_id, device_id)
)
- return result
+ device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
+ device_results[f"{algorithm}:{key_id}"] = key_json
- return await self.db_pool.runInteraction(
- "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
- )
+ return results
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 44018c1c31..bddf5ef192 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -671,27 +671,97 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
- async def get_oldest_events_with_depth_in_room(self, room_id):
+ async def get_oldest_event_ids_with_depth_in_room(self, room_id) -> Dict[str, int]:
+ """Gets the oldest events(backwards extremities) in the room along with the
+ aproximate depth.
+
+ We use this function so that we can compare and see if someones current
+ depth at their current scrollback is within pagination range of the
+ event extremeties. If the current depth is close to the depth of given
+ oldest event, we can trigger a backfill.
+
+ Args:
+ room_id: Room where we want to find the oldest events
+
+ Returns:
+ Map from event_id to depth
+ """
+
+ def get_oldest_event_ids_with_depth_in_room_txn(txn, room_id):
+ # Assemble a dictionary with event_id -> depth for the oldest events
+ # we know of in the room. Backwards extremeties are the oldest
+ # events we know of in the room but we only know of them because
+ # some other event referenced them by prev_event and aren't peristed
+ # in our database yet (meaning we don't know their depth
+ # specifically). So we need to look for the aproximate depth from
+ # the events connected to the current backwards extremeties.
+ sql = """
+ SELECT b.event_id, MAX(e.depth) FROM events as e
+ /**
+ * Get the edge connections from the event_edges table
+ * so we can see whether this event's prev_events points
+ * to a backward extremity in the next join.
+ */
+ INNER JOIN event_edges as g
+ ON g.event_id = e.event_id
+ /**
+ * We find the "oldest" events in the room by looking for
+ * events connected to backwards extremeties (oldest events
+ * in the room that we know of so far).
+ */
+ INNER JOIN event_backward_extremities as b
+ ON g.prev_event_id = b.event_id
+ WHERE b.room_id = ? AND g.is_state is ?
+ GROUP BY b.event_id
+ """
+
+ txn.execute(sql, (room_id, False))
+
+ return dict(txn)
+
return await self.db_pool.runInteraction(
- "get_oldest_events_with_depth_in_room",
- self.get_oldest_events_with_depth_in_room_txn,
+ "get_oldest_event_ids_with_depth_in_room",
+ get_oldest_event_ids_with_depth_in_room_txn,
room_id,
)
- def get_oldest_events_with_depth_in_room_txn(self, txn, room_id):
- sql = (
- "SELECT b.event_id, MAX(e.depth) FROM events as e"
- " INNER JOIN event_edges as g"
- " ON g.event_id = e.event_id"
- " INNER JOIN event_backward_extremities as b"
- " ON g.prev_event_id = b.event_id"
- " WHERE b.room_id = ? AND g.is_state is ?"
- " GROUP BY b.event_id"
- )
+ async def get_insertion_event_backwards_extremities_in_room(
+ self, room_id
+ ) -> Dict[str, int]:
+ """Get the insertion events we know about that we haven't backfilled yet.
- txn.execute(sql, (room_id, False))
+ We use this function so that we can compare and see if someones current
+ depth at their current scrollback is within pagination range of the
+ insertion event. If the current depth is close to the depth of given
+ insertion event, we can trigger a backfill.
- return dict(txn)
+ Args:
+ room_id: Room where we want to find the oldest events
+
+ Returns:
+ Map from event_id to depth
+ """
+
+ def get_insertion_event_backwards_extremities_in_room_txn(txn, room_id):
+ sql = """
+ SELECT b.event_id, MAX(e.depth) FROM insertion_events as i
+ /* We only want insertion events that are also marked as backwards extremities */
+ INNER JOIN insertion_event_extremities as b USING (event_id)
+ /* Get the depth of the insertion event from the events table */
+ INNER JOIN events AS e USING (event_id)
+ WHERE b.room_id = ?
+ GROUP BY b.event_id
+ """
+
+ txn.execute(sql, (room_id,))
+
+ return dict(txn)
+
+ return await self.db_pool.runInteraction(
+ "get_insertion_event_backwards_extremities_in_room",
+ get_insertion_event_backwards_extremities_in_room_txn,
+ room_id,
+ )
async def get_max_depth_of(self, event_ids: List[str]) -> Tuple[str, int]:
"""Returns the event ID and depth for the event that has the max depth from a set of event IDs
@@ -1041,7 +1111,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
if row[1] not in event_results:
queue.put((-row[0], row[1]))
- # Navigate up the DAG by prev_event
txn.execute(query, (event_id, False, limit - len(event_results)))
prev_event_id_results = txn.fetchall()
logger.debug(
@@ -1136,6 +1205,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
_delete_old_forward_extrem_cache_txn,
)
+ async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
+ await self.db_pool.simple_upsert(
+ table="insertion_event_extremities",
+ keyvalues={"event_id": event_id},
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ },
+ insertion_values={},
+ desc="insert_insertion_extremity",
+ lock=False,
+ )
+
async def insert_received_event_to_staging(
self, origin: str, event: EventBase
) -> None:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 86baf397fb..40b53274fb 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1845,6 +1845,18 @@ class PersistEventsStore:
},
)
+ # When we receive an event with a `chunk_id` referencing the
+ # `next_chunk_id` of the insertion event, we can remove it from the
+ # `insertion_event_extremities` table.
+ sql = """
+ DELETE FROM insertion_event_extremities WHERE event_id IN (
+ SELECT event_id FROM insertion_events
+ WHERE next_chunk_id = ?
+ )
+ """
+
+ txn.execute(sql, (chunk_id,))
+
def _handle_redaction(self, txn, redacted_event_id):
"""Handles receiving a redaction and checking whether we need to remove
any redacted relations from the database.
@@ -2101,15 +2113,17 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
+ # From the events passed in, add all of the prev events as backwards extremities.
+ # Ignore any events that are already backwards extrems or outliers.
query = (
"INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
- " SELECT 1 FROM event_backward_extremities"
- " WHERE event_id = ? AND room_id = ?"
+ " SELECT 1 FROM event_backward_extremities"
+ " WHERE event_id = ? AND room_id = ?"
" )"
" AND NOT EXISTS ("
- " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
- " AND outlier = ?"
+ " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
+ " AND outlier = ?"
" )"
)
@@ -2123,6 +2137,8 @@ class PersistEventsStore:
],
)
+ # Delete all these events that we've already fetched and now know that their
+ # prev events are the new backwards extremeties.
query = (
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 3c86adab56..375463e4e9 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -14,7 +14,6 @@
import logging
import threading
-from collections import namedtuple
from typing import (
Collection,
Container,
@@ -27,6 +26,7 @@ from typing import (
overload,
)
+import attr
from constantly import NamedConstant, Names
from typing_extensions import Literal
@@ -42,7 +42,11 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
-from synapse.logging.context import PreserveLoggingContext, current_context
+from synapse.logging.context import (
+ PreserveLoggingContext,
+ current_context,
+ make_deferred_yieldable,
+)
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
@@ -56,6 +60,8 @@ from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
@@ -74,7 +80,10 @@ EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+@attr.s(slots=True, auto_attribs=True)
+class _EventCacheEntry:
+ event: EventBase
+ redacted_event: Optional[EventBase]
class EventRedactBehaviour(Names):
@@ -161,6 +170,13 @@ class EventsWorkerStore(SQLBaseStore):
max_size=hs.config.caches.event_cache_size,
)
+ # Map from event ID to a deferred that will result in a map from event
+ # ID to cache entry. Note that the returned dict may not have the
+ # requested event in it if the event isn't in the DB.
+ self._current_event_fetches: Dict[
+ str, ObservableDeferred[Dict[str, _EventCacheEntry]]
+ ] = {}
+
self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
@@ -476,7 +492,9 @@ class EventsWorkerStore(SQLBaseStore):
return events
- async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_cache_or_db(
+ self, event_ids: Iterable[str], allow_rejected: bool = False
+ ) -> Dict[str, _EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -485,53 +503,107 @@ class EventsWorkerStore(SQLBaseStore):
Args:
- event_ids (Iterable[str]): The event_ids of the events to fetch
+ event_ids: The event_ids of the events to fetch
- allow_rejected (bool): Whether to include rejected events. If False,
+ allow_rejected: Whether to include rejected events. If False,
rejected events are omitted from the response.
Returns:
- Dict[str, _EventCacheEntry]:
- map from event id to result
+ map from event id to result
"""
event_entry_map = self._get_events_from_cache(
- event_ids, allow_rejected=allow_rejected
+ event_ids,
)
- missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+ missing_events_ids = {e for e in event_ids if e not in event_entry_map}
+
+ # We now look up if we're already fetching some of the events in the DB,
+ # if so we wait for those lookups to finish instead of pulling the same
+ # events out of the DB multiple times.
+ already_fetching: Dict[str, defer.Deferred] = {}
+
+ for event_id in missing_events_ids:
+ deferred = self._current_event_fetches.get(event_id)
+ if deferred is not None:
+ # We're already pulling the event out of the DB. Add the deferred
+ # to the collection of deferreds to wait on.
+ already_fetching[event_id] = deferred.observe()
+
+ missing_events_ids.difference_update(already_fetching)
if missing_events_ids:
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))
+ # Add entries to `self._current_event_fetches` for each event we're
+ # going to pull from the DB. We use a single deferred that resolves
+ # to all the events we pulled from the DB (this will result in this
+ # function returning more events than requested, but that can happen
+ # already due to `_get_events_from_db`).
+ fetching_deferred: ObservableDeferred[
+ Dict[str, _EventCacheEntry]
+ ] = ObservableDeferred(defer.Deferred())
+ for event_id in missing_events_ids:
+ self._current_event_fetches[event_id] = fetching_deferred
+
# Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
- missing_events = await self._get_events_from_db(
- missing_events_ids, allow_rejected=allow_rejected
- )
+ try:
+ missing_events = await self._get_events_from_db(
+ missing_events_ids,
+ )
- event_entry_map.update(missing_events)
+ event_entry_map.update(missing_events)
+ except Exception as e:
+ with PreserveLoggingContext():
+ fetching_deferred.errback(e)
+ raise e
+ finally:
+ # Ensure that we mark these events as no longer being fetched.
+ for event_id in missing_events_ids:
+ self._current_event_fetches.pop(event_id, None)
+
+ with PreserveLoggingContext():
+ fetching_deferred.callback(missing_events)
+
+ if already_fetching:
+ # Wait for the other event requests to finish and add their results
+ # to ours.
+ results = await make_deferred_yieldable(
+ defer.gatherResults(
+ already_fetching.values(),
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
+
+ for result in results:
+ event_entry_map.update(result)
+
+ if not allow_rejected:
+ event_entry_map = {
+ event_id: entry
+ for event_id, entry in event_entry_map.items()
+ if not entry.event.rejected_reason
+ }
return event_entry_map
def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,))
- def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
- """Fetch events from the caches
+ def _get_events_from_cache(
+ self, events: Iterable[str], update_metrics: bool = True
+ ) -> Dict[str, _EventCacheEntry]:
+ """Fetch events from the caches.
- Args:
- events (Iterable[str]): list of event_ids to fetch
- allow_rejected (bool): Whether to return events that were rejected
- update_metrics (bool): Whether to update the cache hit ratio metrics
+ May return rejected events.
- Returns:
- dict of event_id -> _EventCacheEntry for each event_id in cache. If
- allow_rejected is `False` then there will still be an entry but it
- will be `None`
+ Args:
+ events: list of event_ids to fetch
+ update_metrics: Whether to update the cache hit ratio metrics
"""
event_map = {}
@@ -542,10 +614,7 @@ class EventsWorkerStore(SQLBaseStore):
if not ret:
continue
- if allow_rejected or not ret.event.rejected_reason:
- event_map[event_id] = ret
- else:
- event_map[event_id] = None
+ event_map[event_id] = ret
return event_map
@@ -672,23 +741,23 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
- async def _get_events_from_db(self, event_ids, allow_rejected=False):
+ async def _get_events_from_db(
+ self, event_ids: Iterable[str]
+ ) -> Dict[str, _EventCacheEntry]:
"""Fetch a bunch of events from the database.
+ May return rejected events.
+
Returned events will be added to the cache for future lookups.
Unknown events are omitted from the response.
Args:
- event_ids (Iterable[str]): The event_ids of the events to fetch
-
- allow_rejected (bool): Whether to include rejected events. If False,
- rejected events are omitted from the response.
+ event_ids: The event_ids of the events to fetch
Returns:
- Dict[str, _EventCacheEntry]:
- map from event id to result. May return extra events which
- weren't asked for.
+ map from event id to result. May return extra events which
+ weren't asked for.
"""
fetched_events = {}
events_to_fetch = event_ids
@@ -717,9 +786,6 @@ class EventsWorkerStore(SQLBaseStore):
rejected_reason = row["rejected_reason"]
- if not allow_rejected and rejected_reason:
- continue
-
# If the event or metadata cannot be parsed, log the error and act
# as if the event is unknown.
try:
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 6ad1a0cf7f..14670c2881 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -29,7 +29,7 @@ from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import UserID
+from synapse.types import UserID, UserInfo
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -146,6 +146,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
@cached()
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
+ """Deprecated: use get_userinfo_by_id instead"""
return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
@@ -166,6 +167,33 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_id",
)
+ async def get_userinfo_by_id(self, user_id: str) -> Optional[UserInfo]:
+ """Get a UserInfo object for a user by user ID.
+
+ Note! Currently uses the cache of `get_user_by_id`. Once that deprecated method is removed,
+ this method should be cached.
+
+ Args:
+ user_id: The user to fetch user info for.
+ Returns:
+ `UserInfo` object if user found, otherwise `None`.
+ """
+ user_data = await self.get_user_by_id(user_id)
+ if not user_data:
+ return None
+ return UserInfo(
+ appservice_id=user_data["appservice_id"],
+ consent_server_notice_sent=user_data["consent_server_notice_sent"],
+ consent_version=user_data["consent_version"],
+ creation_ts=user_data["creation_ts"],
+ is_admin=bool(user_data["admin"]),
+ is_deactivated=bool(user_data["deactivated"]),
+ is_guest=bool(user_data["is_guest"]),
+ is_shadow_banned=bool(user_data["shadow_banned"]),
+ user_id=UserID.from_string(user_data["name"]),
+ user_type=user_data["user_type"],
+ )
+
async def is_trial_user(self, user_id: str) -> bool:
"""Checks if user is in the "trial" period, i.e. within the first
N days of registration defined by `mau_trial_days` config
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 68f1b40ea6..e8157ba3d4 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -629,14 +629,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# We don't update the event cache hit ratio as it completely throws off
# the hit ratio counts. After all, we don't populate the cache if we
# miss it here
- event_map = self._get_events_from_cache(
- member_event_ids, allow_rejected=False, update_metrics=False
- )
+ event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)
missing_member_event_ids = []
for event_id in member_event_ids:
ev_entry = event_map.get(event_id)
- if ev_entry:
+ if ev_entry and not ev_entry.event.rejected_reason:
if ev_entry.event.membership == Membership.JOIN:
users_in_room[ev_entry.event.state_key] = ProfileInfo(
display_name=ev_entry.event.content.get("displayname", None),
|