diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index f39f556c20..edc3624fed 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate((room_id,))
+ self.get_unread_message_count_for_user.invalidate_many((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 504babaa7e..18297cf3b8 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -411,7 +411,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
_get_if_maybe_push_in_range_for_user_txn,
)
- def add_push_actions_to_staging(self, event_id, user_id_actions):
+ async def add_push_actions_to_staging(self, event_id, user_id_actions):
"""Add the push actions for the event to the push action staging area.
Args:
@@ -457,7 +457,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
),
)
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 6f2e0d15cc..0c9c02afa1 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -53,6 +53,47 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"],
)
+STATE_EVENT_TYPES_TO_MARK_UNREAD = {
+ EventTypes.Topic,
+ EventTypes.Name,
+ EventTypes.RoomAvatar,
+ EventTypes.Tombstone,
+}
+
+
+def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
+ # Exclude rejected and soft-failed events.
+ if context.rejected or event.internal_metadata.is_soft_failed():
+ return False
+
+ # Exclude notices.
+ if (
+ not event.is_state()
+ and event.type == EventTypes.Message
+ and event.content.get("msgtype") == "m.notice"
+ ):
+ return False
+
+    # Exclude edits. Event content is client-supplied, so guard against
+    # "m.relates_to" not being a dict.
+    relates_to = event.content.get("m.relates_to", {})
+    if (
+        isinstance(relates_to, dict)
+        and relates_to.get("rel_type") == RelationTypes.REPLACE
+    ):
+        return False
+
+ # Mark events that have a non-empty string body as unread.
+ body = event.content.get("body")
+ if isinstance(body, str) and body:
+ return True
+
+ # Mark some state events as unread.
+ if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
+ return True
+
+ # Mark encrypted events as unread.
+ if not event.is_state() and event.type == EventTypes.Encrypted:
+ return True
+
+ return False
+
def encode_json(json_object):
"""
@@ -196,6 +237,10 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
+ self.store.get_unread_message_count_for_user.invalidate_many(
+ (event.room_id,),
+ )
+
for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
@@ -817,8 +862,9 @@ class PersistEventsStore:
"contains_url": (
"url" in event.content and isinstance(event.content["url"], str)
),
+ "count_as_unread": should_count_as_unread(event, context),
}
- for event, _ in events_and_contexts
+ for event, context in events_and_contexts
],
)
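Not part of the patch, but for illustration: a standalone sketch of the classification rules above, using a simplified stand-in for EventBase/EventContext and string literals for the event-type constants. The rejection/soft-failure and state-event checks are omitted for brevity:

```python
from typing import Optional


class FakeEvent:
    """Simplified stand-in for synapse.events.EventBase."""

    def __init__(self, type: str, content: dict, state_key: Optional[str] = None):
        self.type = type
        self.content = content
        self._state_key = state_key

    def is_state(self) -> bool:
        return self._state_key is not None


def counts_as_unread(event: FakeEvent) -> bool:
    # Notices never count.
    if (
        not event.is_state()
        and event.type == "m.room.message"
        and event.content.get("msgtype") == "m.notice"
    ):
        return False
    # Edits never count.
    if event.content.get("m.relates_to", {}).get("rel_type") == "m.replace":
        return False
    # Anything with a non-empty string body counts.
    body = event.content.get("body")
    if isinstance(body, str) and body:
        return True
    # Encrypted non-state events count, since their body can't be inspected.
    return not event.is_state() and event.type == "m.room.encrypted"


assert counts_as_unread(FakeEvent("m.room.message", {"msgtype": "m.text", "body": "hi"}))
assert not counts_as_unread(FakeEvent("m.room.message", {"msgtype": "m.notice", "body": "hi"}))
assert counts_as_unread(FakeEvent("m.room.encrypted", {}))
```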
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index e812c67078..b03b259636 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
+from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import (
+ Cache,
+ _CacheContext,
+ cached,
+ cachedInlineCallbacks,
+)
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
+ @cached(tree=True, cache_context=True)
+ async def get_unread_message_count_for_user(
+ self, room_id: str, user_id: str, cache_context: _CacheContext,
+ ) -> int:
+ """Retrieve the count of unread messages for the given room and user.
+
+ Args:
+ room_id: The ID of the room to count unread messages in.
+ user_id: The ID of the user to count unread messages for.
+
+ Returns:
+ The number of unread messages for the given user in the given room.
+ """
+ with Measure(self._clock, "get_unread_message_count_for_user"):
+ last_read_event_id = await self.get_last_receipt_event_id_for_user(
+ user_id=user_id,
+ room_id=room_id,
+ receipt_type="m.read",
+ on_invalidate=cache_context.invalidate,
+ )
+
+ return await self.db.runInteraction(
+ "get_unread_message_count_for_user",
+ self._get_unread_message_count_for_user_txn,
+ user_id,
+ room_id,
+ last_read_event_id,
+ )
+
+ def _get_unread_message_count_for_user_txn(
+ self,
+ txn: Cursor,
+ user_id: str,
+ room_id: str,
+ last_read_event_id: Optional[str],
+ ) -> int:
+ if last_read_event_id:
+ # Get the stream ordering for the last read event.
+ stream_ordering = self.db.simple_select_one_onecol_txn(
+ txn=txn,
+ table="events",
+ keyvalues={"room_id": room_id, "event_id": last_read_event_id},
+ retcol="stream_ordering",
+ )
+ else:
+            # If there's no read receipt for that room, the user probably hasn't
+            # opened it yet, in which case we use the stream ordering of their join
+            # event. We can't just use 0, otherwise messages that other local users
+            # sent before this user joined would be counted as well.
+ txn.execute(
+ """
+ SELECT stream_ordering FROM local_current_membership
+ LEFT JOIN events USING (event_id, room_id)
+ WHERE membership = 'join'
+ AND user_id = ?
+ AND room_id = ?
+ """,
+ (user_id, room_id),
+ )
+ row = txn.fetchone()
+
+ if row is None:
+ return 0
+
+ stream_ordering = row[0]
+
+ # Count the messages that qualify as unread after the stream ordering we've just
+ # retrieved.
+ sql = """
+ SELECT COUNT(*) FROM events
+ WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
+ """
+
+ txn.execute(sql, (user_id, room_id, stream_ordering))
+ row = txn.fetchone()
+
+ return row[0] if row else 0
+
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
index 6546569139..b53fe35c33 100644
--- a/synapse/storage/data_stores/main/purge_events.py
+++ b/synapse/storage/data_stores/main/purge_events.py
@@ -62,6 +62,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# event_json
# event_push_actions
# event_reference_hashes
+ # event_relations
# event_search
# event_to_state_groups
# events
@@ -209,6 +210,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_edges",
"event_forward_extremities",
"event_reference_hashes",
+ "event_relations",
"event_search",
"rejections",
):
diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
new file mode 100644
index 0000000000..531b532c73
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Store a boolean value in the events table for whether the event should be counted in
+-- the unread_count property of sync responses.
+ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;
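One consequence of this migration worth noting: pre-existing rows get NULL in the new column, and "AND count_as_unread" in the counting query treats NULL as not true, so events persisted before this schema version are never counted as unread. A quick sqlite3 check of that behaviour:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (event_id TEXT, count_as_unread BOOLEAN)")
conn.execute("INSERT INTO events VALUES ('$old', NULL), ('$new', 1)")
rows = conn.execute("SELECT event_id FROM events WHERE count_as_unread").fetchall()
print(rows)  # [('$new',)] -- NULL is not true, so '$old' is never counted
```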
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 3be20c866a..ce8757a400 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,11 +49,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E
from synapse.storage.types import Connection, Cursor
from synapse.types import Collection
-logger = logging.getLogger(__name__)
-
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
+logger = logging.getLogger(__name__)
+
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")
@@ -233,7 +233,7 @@ class LoggingTransaction:
try:
return func(sql, *args)
except Exception as e:
- logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+ sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
finally:
secs = time.time() - start
@@ -419,7 +419,7 @@ class Database(object):
except self.engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
- logger.warning(
+ transaction_logger.warning(
"[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
)
if i < N:
@@ -427,18 +427,20 @@ class Database(object):
try:
conn.rollback()
except self.engine.module.Error as e1:
- logger.warning("[TXN EROLL] {%s} %s", name, e1)
+ transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
continue
raise
except self.engine.module.DatabaseError as e:
if self.engine.is_deadlock(e):
- logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+ transaction_logger.warning(
+ "[TXN DEADLOCK] {%s} %d/%d", name, i, N
+ )
if i < N:
i += 1
try:
conn.rollback()
except self.engine.module.Error as e1:
- logger.warning(
+ transaction_logger.warning(
"[TXN EROLL] {%s} %s", name, e1,
)
continue
@@ -478,7 +480,7 @@ class Database(object):
# [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
cursor.close()
except Exception as e:
- logger.debug("[TXN FAIL] {%s} %s", name, e)
+ transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = monotonic_time()
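Routing these messages to the pre-existing `synapse.storage.SQL` and `synapse.storage.txn` loggers means their verbosity can now be tuned independently of `synapse.storage.database`. A minimal stdlib sketch, using the logger names defined above:

```python
import logging

logging.basicConfig(level=logging.INFO)
# [SQL FAIL] lines now come from this logger:
logging.getLogger("synapse.storage.SQL").setLevel(logging.DEBUG)
# [TXN OPERROR] / [TXN DEADLOCK] / [TXN FAIL] lines from this one:
logging.getLogger("synapse.storage.txn").setLevel(logging.WARNING)
```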
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 78fbdcdee8..4a164834d9 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.events import FrozenEvent
+from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -192,12 +192,11 @@ class EventsPersistenceStorage(object):
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
- @defer.inlineCallbacks
- def persist_events(
+ async def persist_events(
self,
- events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
- ):
+ ) -> int:
"""
Write events to the database
Args:
@@ -207,7 +206,7 @@ class EventsPersistenceStorage(object):
which might update the current state etc.
Returns:
- Deferred[int]: the stream ordering of the latest persisted event
+            The stream ordering of the latest persisted event
"""
partitioned = {}
for event, ctx in events_and_contexts:
@@ -223,22 +222,19 @@ class EventsPersistenceStorage(object):
for room_id in partitioned:
self._maybe_start_persisting(room_id)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
- max_persisted_id = yield self.main_store.get_current_events_token()
-
- return max_persisted_id
+ return self.main_store.get_current_events_token()
- @defer.inlineCallbacks
- def persist_event(
- self, event: FrozenEvent, context: EventContext, backfilled: bool = False
- ):
+ async def persist_event(
+ self, event: EventBase, context: EventContext, backfilled: bool = False
+ ) -> Tuple[int, int]:
"""
Returns:
- Deferred[Tuple[int, int]]: the stream ordering of ``event``,
- and the stream ordering of the latest persisted event
+ The stream ordering of `event`, and the stream ordering of the
+ latest persisted event
"""
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled
@@ -246,9 +242,9 @@ class EventsPersistenceStorage(object):
self._maybe_start_persisting(event.room_id)
- yield make_deferred_yieldable(deferred)
+ await make_deferred_yieldable(deferred)
- max_persisted_id = yield self.main_store.get_current_events_token()
+ max_persisted_id = self.main_store.get_current_events_token()
return (event.internal_metadata.stream_ordering, max_persisted_id)
def _maybe_start_persisting(self, room_id: str):
@@ -262,7 +258,7 @@ class EventsPersistenceStorage(object):
async def _persist_events(
self,
- events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
):
"""Calculates the change to current state and forward extremities, and
@@ -439,7 +435,7 @@ class EventsPersistenceStorage(object):
async def _calculate_new_extremities(
self,
room_id: str,
- event_contexts: List[Tuple[FrozenEvent, EventContext]],
+ event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str],
):
"""Calculates the new forward extremities for a room given events to
@@ -497,7 +493,7 @@ class EventsPersistenceStorage(object):
async def _get_new_state_after_events(
self,
room_id: str,
- events_context: List[Tuple[FrozenEvent, EventContext]],
+ events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Iterable[str],
new_latest_event_ids: Iterable[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
@@ -683,7 +679,7 @@ class EventsPersistenceStorage(object):
async def _is_server_still_joined(
self,
room_id: str,
- ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
+ ev_ctx_rm: List[Tuple[EventBase, EventContext]],
delta: DeltaState,
current_state: Optional[StateMap[str]],
potentially_left_users: Set[str],
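With this conversion, callers await `persist_event` directly rather than yielding a Deferred. A hedged usage sketch, with a stub class standing in for `EventsPersistenceStorage` (the real method takes an EventBase and EventContext):

```python
import asyncio
from typing import Tuple


class StubPersistenceStorage:
    async def persist_event(self, event, context, backfilled=False) -> Tuple[int, int]:
        # Stub: (stream ordering of `event`, latest persisted stream ordering).
        return (42, 42)


async def main() -> None:
    storage = StubPersistenceStorage()
    stream_ordering, max_stream_id = await storage.persist_event("event", "context")
    print(stream_ordering, max_stream_id)


asyncio.run(main())
```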
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index fdc0abf5cf..79d9f06e2e 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -15,8 +15,7 @@
import itertools
import logging
-
-from twisted.internet import defer
+from typing import Set
logger = logging.getLogger(__name__)
@@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
def __init__(self, hs, stores):
self.stores = stores
- @defer.inlineCallbacks
- def purge_room(self, room_id: str):
+ async def purge_room(self, room_id: str):
"""Deletes all record of a room
"""
- state_groups_to_delete = yield self.stores.main.purge_room(room_id)
- yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
+ state_groups_to_delete = await self.stores.main.purge_room(room_id)
+ await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
- @defer.inlineCallbacks
- def purge_history(self, room_id, token, delete_local_events):
+ async def purge_history(
+ self, room_id: str, token: str, delete_local_events: bool
+ ) -> None:
"""Deletes room history before a certain point
Args:
- room_id (str):
+ room_id: The room ID
- token (str): A topological token to delete events before
+ token: A topological token to delete events before
- delete_local_events (bool):
+ delete_local_events:
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
"""
- state_groups = yield self.stores.main.purge_history(
+ state_groups = await self.stores.main.purge_history(
room_id, token, delete_local_events
)
logger.info("[purge] finding state groups that can be deleted")
- sg_to_delete = yield self._find_unreferenced_groups(state_groups)
+ sg_to_delete = await self._find_unreferenced_groups(state_groups)
- yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
+ await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
- @defer.inlineCallbacks
- def _find_unreferenced_groups(self, state_groups):
+ async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
"""Used when purging history to figure out which state groups can be
deleted.
Args:
- state_groups (set[int]): Set of state groups referenced by events
+ state_groups: Set of state groups referenced by events
that are going to be deleted.
Returns:
- Deferred[set[int]] The set of state groups that can be deleted.
+ The set of state groups that can be deleted.
"""
# Graph of state group -> previous group
graph = {}
@@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
current_search = set(itertools.islice(next_to_search, 100))
next_to_search -= current_search
- referenced = yield self.stores.main.get_referenced_state_groups(
+ referenced = await self.stores.main.get_referenced_state_groups(
current_search
)
referenced_groups |= referenced
@@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
# groups that are referenced.
current_search -= referenced
- edges = yield self.stores.state.get_previous_state_groups(current_search)
+ edges = await self.stores.state.get_previous_state_groups(current_search)
prevs = set(edges.values())
# We don't bother re-handling groups we've already seen
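For reference, a self-contained sketch (made-up data, in-memory maps standing in for the stores) of the batched walk in `_find_unreferenced_groups`: referenced groups are kept, and the walk deliberately does not continue past them, so their ancestors are left untouched:

```python
import itertools
from typing import Set

prev_group = {3: 2, 2: 1}   # state group -> previous group
referenced_by_events = {2}  # groups still referenced by live events


def find_unreferenced(state_groups: Set[int]) -> Set[int]:
    seen, referenced, to_search = set(state_groups), set(), set(state_groups)
    while to_search:
        batch = set(itertools.islice(to_search, 100))
        to_search -= batch
        referenced |= batch & referenced_by_events
        # Only walk past groups that are not referenced themselves.
        prevs = {prev_group[g] for g in batch - referenced_by_events if g in prev_group}
        to_search |= prevs - seen
        seen |= prevs
    return seen - referenced


print(find_unreferenced({3}))  # {3}: group 2 is referenced, so the walk stops there
```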
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index dc568476f4..49ee9c9a74 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,13 +14,12 @@
# limitations under the License.
import logging
-from typing import Iterable, List, TypeVar
+from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar
import attr
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
from synapse.types import StateMap
logger = logging.getLogger(__name__)
@@ -34,16 +33,16 @@ class StateFilter(object):
"""A filter used when querying for state.
Attributes:
- types (dict[str, set[str]|None]): Map from type to set of state keys (or
- None). This specifies which state_keys for the given type to fetch
- from the DB. If None then all events with that type are fetched. If
- the set is empty then no events with that type are fetched.
- include_others (bool): Whether to fetch events with types that do not
+ types: Map from type to set of state keys (or None). This specifies
+ which state_keys for the given type to fetch from the DB. If None
+ then all events with that type are fetched. If the set is empty
+ then no events with that type are fetched.
+ include_others: Whether to fetch events with types that do not
appear in `types`.
"""
- types = attr.ib()
- include_others = attr.ib(default=False)
+ types = attr.ib(type=Dict[str, Optional[Set[str]]])
+ include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
@@ -52,36 +51,35 @@ class StateFilter(object):
self.types = {k: v for k, v in self.types.items() if v is not None}
@staticmethod
- def all():
+ def all() -> "StateFilter":
"""Creates a filter that fetches everything.
Returns:
- StateFilter
+ The new state filter.
"""
return StateFilter(types={}, include_others=True)
@staticmethod
- def none():
+ def none() -> "StateFilter":
"""Creates a filter that fetches nothing.
Returns:
- StateFilter
+ The new state filter.
"""
return StateFilter(types={}, include_others=False)
@staticmethod
- def from_types(types):
+ def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
"""Creates a filter that only fetches the given types
Args:
- types (Iterable[tuple[str, str|None]]): A list of type and state
- keys to fetch. A state_key of None fetches everything for
- that type
+ types: A list of type and state keys to fetch. A state_key of None
+ fetches everything for that type
Returns:
- StateFilter
+ The new state filter.
"""
- type_dict = {}
+ type_dict = {} # type: Dict[str, Optional[Set[str]]]
for typ, s in types:
if typ in type_dict:
if type_dict[typ] is None:
@@ -91,24 +89,24 @@ class StateFilter(object):
type_dict[typ] = None
continue
- type_dict.setdefault(typ, set()).add(s)
+ type_dict.setdefault(typ, set()).add(s) # type: ignore
return StateFilter(types=type_dict)
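A standalone sketch (plain function, not the real class) of the merging rule implemented above: a state_key of None acts as a wildcard for its type and swallows any specific keys, whichever order they arrive in:

```python
from typing import Dict, Iterable, Optional, Set, Tuple


def merge_types(types: Iterable[Tuple[str, Optional[str]]]) -> Dict[str, Optional[Set[str]]]:
    type_dict: Dict[str, Optional[Set[str]]] = {}
    for typ, key in types:
        if type_dict.get(typ, ...) is None:
            continue  # already a wildcard for this type
        if key is None:
            type_dict[typ] = None  # wildcard: fetch everything of this type
            continue
        type_dict.setdefault(typ, set()).add(key)
    return type_dict


print(merge_types([("m.room.member", "@a:hs"), ("m.room.member", None)]))
# {'m.room.member': None} -- the wildcard wins
```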
@staticmethod
- def from_lazy_load_member_list(members):
+ def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member
events for the given users
Args:
- members (iterable[str]): Set of user IDs
+ members: Set of user IDs
Returns:
- StateFilter
+ The new state filter
"""
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
- def return_expanded(self):
+ def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the
current one, i.e. anything that passes the current filter will pass
@@ -130,7 +128,7 @@ class StateFilter(object):
return all non-member events
Returns:
- StateFilter
+ The new state filter.
"""
if self.is_full():
@@ -167,7 +165,7 @@ class StateFilter(object):
include_others=True,
)
- def make_sql_filter_clause(self):
+ def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
"""Converts the filter to an SQL clause.
For example:
@@ -179,13 +177,12 @@ class StateFilter(object):
Returns:
- tuple[str, list]: The SQL string (may be empty) and arguments. An
- empty SQL string is returned when the filter matches everything
- (i.e. is "full").
+ The SQL string (may be empty) and arguments. An empty SQL string is
+ returned when the filter matches everything (i.e. is "full").
"""
where_clause = ""
- where_args = []
+ where_args = [] # type: List[str]
if self.is_full():
return where_clause, where_args
@@ -221,7 +218,7 @@ class StateFilter(object):
return where_clause, where_args
- def max_entries_returned(self):
+ def max_entries_returned(self) -> Optional[int]:
"""Returns the maximum number of entries this filter will return if
known, otherwise returns None.
@@ -260,33 +257,33 @@ class StateFilter(object):
return filtered_state
- def is_full(self):
+ def is_full(self) -> bool:
"""Whether this filter fetches everything or not
Returns:
- bool
+ True if the filter fetches everything.
"""
return self.include_others and not self.types
- def has_wildcards(self):
+ def has_wildcards(self) -> bool:
"""Whether the filter includes wildcards or is attempting to fetch
specific state.
Returns:
- bool
+ True if the filter includes wildcards.
"""
return self.include_others or any(
state_keys is None for state_keys in self.types.values()
)
- def concrete_types(self):
+ def concrete_types(self) -> List[Tuple[str, str]]:
"""Returns a list of concrete type/state_keys (i.e. not None) that
will be fetched. This will be a complete list if `has_wildcards`
returns False, but otherwise will be a subset (or even empty).
Returns:
- list[tuple[str,str]]
+ A list of type/state_keys tuples.
"""
return [
(t, s)
@@ -295,7 +292,7 @@ class StateFilter(object):
for s in state_keys
]
- def get_member_split(self):
+ def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching
against non member state.
@@ -307,7 +304,7 @@ class StateFilter(object):
state caches).
Returns:
- tuple[StateFilter, StateFilter]: The member and non member filters
+            The member and non-member filters
"""
if EventTypes.Member in self.types:
@@ -340,6 +337,9 @@ class StateGroupStorage(object):
"""Given a state group try to return a previous group and a delta between
the old and the new.
+ Args:
+            state_group: The state group to get the delta for.
+
Returns:
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids)
@@ -347,55 +347,59 @@ class StateGroupStorage(object):
return self.stores.state.get_state_group_delta(state_group)
- @defer.inlineCallbacks
- def get_state_groups_ids(self, _room_id, event_ids):
+ async def get_state_groups_ids(
+ self, _room_id: str, event_ids: Iterable[str]
+ ) -> Dict[int, StateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
- _room_id (str): id of the room for these events
- event_ids (iterable[str]): ids of the events
+ _room_id: id of the room for these events
+ event_ids: ids of the events
Returns:
- Deferred[dict[int, StateMap[str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
return {}
- event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+ event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
- group_to_state = yield self.stores.state._get_state_for_groups(groups)
+ group_to_state = await self.stores.state._get_state_for_groups(groups)
return group_to_state
- @defer.inlineCallbacks
- def get_state_ids_for_group(self, state_group):
+ async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group
Args:
- state_group (int)
+ state_group: A state group for which we want to get the state IDs.
Returns:
- Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
+ Resolves to a map of (type, state_key) -> event_id
"""
- group_to_state = yield self._get_state_for_groups((state_group,))
+ group_to_state = await self._get_state_for_groups((state_group,))
return group_to_state[state_group]
- @defer.inlineCallbacks
- def get_state_groups(self, room_id, event_ids):
+ async def get_state_groups(
+ self, room_id: str, event_ids: Iterable[str]
+ ) -> Dict[int, List[EventBase]]:
""" Get the state groups for the given list of event_ids
+
+ Args:
+ room_id: ID of the room for these events.
+ event_ids: The event IDs to retrieve state for.
+
Returns:
- Deferred[dict[int, list[EventBase]]]:
- dict of state_group_id -> list of state events.
+ dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}
- group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
+ group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
- state_event_map = yield self.stores.main.get_events(
+ state_event_map = await self.stores.main.get_events(
[
ev_id
for group_ids in group_to_ids.values()
@@ -423,31 +427,34 @@ class StateGroupStorage(object):
groups: list of state group IDs to query
state_filter: The state filter used to fetch state
from the database.
+
Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
- @defer.inlineCallbacks
- def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
+ async def get_state_for_events(
+ self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
+ ):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
+
Args:
- event_ids (list[string])
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ event_ids: The events to fetch the state of.
+ state_filter: The state filter used to fetch state.
+
Returns:
- deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
+ A dict of (event_id) -> (type, state_key) -> [state_events]
"""
- event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+ event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
- group_to_state = yield self.stores.state._get_state_for_groups(
+ group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)
- state_event_map = yield self.stores.main.get_events(
+ state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
@@ -463,24 +470,24 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids}
- @defer.inlineCallbacks
- def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
+ async def get_state_ids_for_events(
+ self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
+ ):
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
- event_ids(list(str)): events whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ event_ids: events whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
Returns:
- A deferred dict from event_id -> (type, state_key) -> event_id
+ A dict from event_id -> (type, state_key) -> event_id
"""
- event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
+ event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
- group_to_state = yield self.stores.state._get_state_for_groups(
+ group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)
@@ -491,36 +498,36 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids}
- @defer.inlineCallbacks
- def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
+ async def get_state_for_event(
+ self, event_id: str, state_filter: StateFilter = StateFilter.all()
+ ):
"""
Get the state dict corresponding to a particular event
Args:
- event_id(str): event whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ event_id: event whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
Returns:
- A deferred dict from (type, state_key) -> state_event
+ A dict from (type, state_key) -> state_event
"""
- state_map = yield self.get_state_for_events([event_id], state_filter)
+ state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id]
- @defer.inlineCallbacks
- def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
+ async def get_state_ids_for_event(
+ self, event_id: str, state_filter: StateFilter = StateFilter.all()
+ ):
"""
Get the state dict corresponding to a particular event
Args:
- event_id(str): event whose state should be returned
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ event_id: event whose state should be returned
+ state_filter: The state filter used to fetch state from the database.
Returns:
-            A deferred dict from (type, state_key) -> state_event
+            A dict from (type, state_key) -> event_id
"""
- state_map = yield self.get_state_ids_for_events([event_id], state_filter)
+ state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
def _get_state_for_groups(
@@ -530,9 +537,8 @@ class StateGroupStorage(object):
filtering by type/state_key
Args:
- groups (iterable[int]): list of state groups for which we want
- to get the state.
-            state_filter (StateFilter): The state filter used to fetch state
-                from the database.
+            groups: list of state groups for which we want to get the state.
+            state_filter: The state filter used to fetch state from the database.
Returns:
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
@@ -540,18 +546,23 @@ class StateGroupStorage(object):
return self.stores.state._get_state_for_groups(groups, state_filter)
def store_state_group(
- self, event_id, room_id, prev_group, delta_ids, current_state_ids
+ self,
+ event_id: str,
+ room_id: str,
+ prev_group: Optional[int],
+ delta_ids: Optional[dict],
+ current_state_ids: dict,
):
"""Store a new set of state, returning a newly assigned state group.
Args:
- event_id (str): The event ID for which the state was calculated
- room_id (str)
- prev_group (int|None): A previous state group for the room, optional.
- delta_ids (dict|None): The delta between state at `prev_group` and
+ event_id: The event ID for which the state was calculated.
+ room_id: ID of the room for which the state was calculated.
+ prev_group: A previous state group for the room, optional.
+ delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
- current_state_ids (dict): The state to store. Map of (type, state_key)
+ current_state_ids: The state to store. Map of (type, state_key)
to event_id.
Returns:
|