diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index f39f556c20..edc3624fed 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate((room_id,))
+ self.get_unread_message_count_for_user.invalidate_many((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 504babaa7e..ad82838901 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -15,11 +15,10 @@
# limitations under the License.
import logging
+from typing import List
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import Database
@@ -166,8 +165,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return {"notify_count": notify_count, "highlight_count": highlight_count}
- @defer.inlineCallbacks
- def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
+ async def get_push_action_users_in_range(
+ self, min_stream_ordering, max_stream_ordering
+ ):
def f(txn):
sql = (
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
@@ -176,26 +176,28 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
- ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
+ ret = await self.db.runInteraction("get_push_action_users_in_range", f)
return ret
- @defer.inlineCallbacks
- def get_unread_push_actions_for_user_in_range_for_http(
- self, user_id, min_stream_ordering, max_stream_ordering, limit=20
- ):
+ async def get_unread_push_actions_for_user_in_range_for_http(
+ self,
+ user_id: str,
+ min_stream_ordering: int,
+ max_stream_ordering: int,
+ limit: int = 20,
+ ) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the httppusher.
Args:
- user_id (str): The user to fetch push actions for.
- min_stream_ordering(int): The exclusive lower bound on the
+ user_id: The user to fetch push actions for.
+ min_stream_ordering: The exclusive lower bound on the
stream ordering of event push actions to fetch.
- max_stream_ordering(int): The inclusive upper bound on the
+ max_stream_ordering: The inclusive upper bound on the
stream ordering of event push actions to fetch.
- limit (int): The maximum number of rows to return.
+ limit: The maximum number of rows to return.
Returns:
- A promise which resolves to a list of dicts with the keys "event_id",
- "room_id", "stream_ordering", "actions".
+ A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions".
The list will be ordered by ascending stream_ordering.
The list will have between 0~limit entries.
"""
@@ -228,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = yield self.db.runInteraction(
+ after_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
@@ -256,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = yield self.db.runInteraction(
+ no_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
@@ -280,23 +282,25 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# one of the subqueries may have hit the limit.
return notifs[:limit]
- @defer.inlineCallbacks
- def get_unread_push_actions_for_user_in_range_for_email(
- self, user_id, min_stream_ordering, max_stream_ordering, limit=20
- ):
+ async def get_unread_push_actions_for_user_in_range_for_email(
+ self,
+ user_id: str,
+ min_stream_ordering: int,
+ max_stream_ordering: int,
+ limit: int = 20,
+ ) -> List[dict]:
"""Get a list of the most recent unread push actions for a given user,
within the given stream ordering range. Called by the emailpusher
Args:
- user_id (str): The user to fetch push actions for.
- min_stream_ordering(int): The exclusive lower bound on the
+ user_id: The user to fetch push actions for.
+ min_stream_ordering: The exclusive lower bound on the
stream ordering of event push actions to fetch.
- max_stream_ordering(int): The inclusive upper bound on the
+ max_stream_ordering: The inclusive upper bound on the
stream ordering of event push actions to fetch.
- limit (int): The maximum number of rows to return.
+ limit: The maximum number of rows to return.
Returns:
- A promise which resolves to a list of dicts with the keys "event_id",
- "room_id", "stream_ordering", "actions", "received_ts".
+ A list of dicts with the keys "event_id", "room_id", "stream_ordering", "actions", "received_ts".
The list will be ordered by descending received_ts.
The list will have between 0~limit entries.
"""
@@ -328,7 +332,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = yield self.db.runInteraction(
+ after_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
@@ -356,7 +360,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = yield self.db.runInteraction(
+ no_read_receipt = await self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
@@ -411,7 +415,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
_get_if_maybe_push_in_range_for_user_txn,
)
- def add_push_actions_to_staging(self, event_id, user_id_actions):
+ async def add_push_actions_to_staging(self, event_id, user_id_actions):
"""Add the push actions for the event to the push action staging area.
Args:
@@ -457,21 +461,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
),
)
- return self.db.runInteraction(
+ return await self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
- @defer.inlineCallbacks
- def remove_push_actions_from_staging(self, event_id):
+ async def remove_push_actions_from_staging(self, event_id: str) -> None:
"""Called if we failed to persist the event to ensure that stale push
actions don't build up in the DB
-
- Args:
- event_id (str)
"""
try:
- res = yield self.db.simple_delete(
+ res = await self.db.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
@@ -606,8 +606,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end
- @defer.inlineCallbacks
- def get_time_of_last_push_action_before(self, stream_ordering):
+ async def get_time_of_last_push_action_before(self, stream_ordering):
def f(txn):
sql = (
"SELECT e.received_ts"
@@ -620,7 +619,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (stream_ordering,))
return txn.fetchone()
- result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
+ result = await self.db.runInteraction("get_time_of_last_push_action_before", f)
return result[0] if result else None
@@ -650,8 +649,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
self._start_rotate_notifs, 30 * 60 * 1000
)
- @defer.inlineCallbacks
- def get_push_actions_for_user(
+ async def get_push_actions_for_user(
self, user_id, before=None, limit=50, only_highlight=False
):
def f(txn):
@@ -682,18 +680,17 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute(sql, args)
return self.db.cursor_to_dict(txn)
- push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
+ push_actions = await self.db.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
- @defer.inlineCallbacks
- def get_latest_push_action_stream_ordering(self):
+ async def get_latest_push_action_stream_ordering(self):
def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
- result = yield self.db.runInteraction(
+ result = await self.db.runInteraction(
"get_latest_push_action_stream_ordering", f
)
return result[0] or 0
@@ -747,8 +744,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
def _start_rotate_notifs(self):
return run_as_background_process("rotate_notifs", self._rotate_notifs)
- @defer.inlineCallbacks
- def _rotate_notifs(self):
+ async def _rotate_notifs(self):
if self._doing_notif_rotation or self.stream_ordering_day_ago is None:
return
self._doing_notif_rotation = True
@@ -757,12 +753,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True:
logger.info("Rotating notifications")
- caught_up = yield self.db.runInteraction(
+ caught_up = await self.db.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn
)
if caught_up:
break
- yield self.hs.get_clock().sleep(self._rotate_delay)
+ await self.hs.get_clock().sleep(self._rotate_delay)
finally:
self._doing_notif_rotation = False
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 6f2e0d15cc..0c9c02afa1 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -53,6 +53,47 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"],
)
+STATE_EVENT_TYPES_TO_MARK_UNREAD = {
+ EventTypes.Topic,
+ EventTypes.Name,
+ EventTypes.RoomAvatar,
+ EventTypes.Tombstone,
+}
+
+
+def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
+ # Exclude rejected and soft-failed events.
+ if context.rejected or event.internal_metadata.is_soft_failed():
+ return False
+
+ # Exclude notices.
+ if (
+ not event.is_state()
+ and event.type == EventTypes.Message
+ and event.content.get("msgtype") == "m.notice"
+ ):
+ return False
+
+ # Exclude edits.
+ relates_to = event.content.get("m.relates_to", {})
+ if relates_to.get("rel_type") == RelationTypes.REPLACE:
+ return False
+
+ # Mark events that have a non-empty string body as unread.
+ body = event.content.get("body")
+ if isinstance(body, str) and body:
+ return True
+
+ # Mark some state events as unread.
+ if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
+ return True
+
+ # Mark encrypted events as unread.
+ if not event.is_state() and event.type == EventTypes.Encrypted:
+ return True
+
+ return False
+
def encode_json(json_object):
"""
@@ -196,6 +237,10 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
+ self.store.get_unread_message_count_for_user.invalidate_many(
+ (event.room_id,),
+ )
+
for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
@@ -817,8 +862,9 @@ class PersistEventsStore:
"contains_url": (
"url" in event.content and isinstance(event.content["url"], str)
),
+ "count_as_unread": should_count_as_unread(event, context),
}
- for event, _ in events_and_contexts
+ for event, context in events_and_contexts
],
)
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index e812c67078..b03b259636 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
+from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
-from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import (
+ Cache,
+ _CacheContext,
+ cached,
+ cachedInlineCallbacks,
+)
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
+ @cached(tree=True, cache_context=True)
+ async def get_unread_message_count_for_user(
+ self, room_id: str, user_id: str, cache_context: _CacheContext,
+ ) -> int:
+ """Retrieve the count of unread messages for the given room and user.
+
+ Args:
+ room_id: The ID of the room to count unread messages in.
+ user_id: The ID of the user to count unread messages for.
+
+ Returns:
+ The number of unread messages for the given user in the given room.
+ """
+ with Measure(self._clock, "get_unread_message_count_for_user"):
+ last_read_event_id = await self.get_last_receipt_event_id_for_user(
+ user_id=user_id,
+ room_id=room_id,
+ receipt_type="m.read",
+ on_invalidate=cache_context.invalidate,
+ )
+
+ return await self.db.runInteraction(
+ "get_unread_message_count_for_user",
+ self._get_unread_message_count_for_user_txn,
+ user_id,
+ room_id,
+ last_read_event_id,
+ )
+
+ def _get_unread_message_count_for_user_txn(
+ self,
+ txn: Cursor,
+ user_id: str,
+ room_id: str,
+ last_read_event_id: Optional[str],
+ ) -> int:
+ if last_read_event_id:
+ # Get the stream ordering for the last read event.
+ stream_ordering = self.db.simple_select_one_onecol_txn(
+ txn=txn,
+ table="events",
+ keyvalues={"room_id": room_id, "event_id": last_read_event_id},
+ retcol="stream_ordering",
+ )
+ else:
+ # If there's no read receipt for that room, it probably means the user hasn't
+ # opened it yet, in which case use the stream ID of their join event.
+ # We can't just set it to 0 otherwise messages from other local users from
+ # before this user joined will be counted as well.
+ txn.execute(
+ """
+ SELECT stream_ordering FROM local_current_membership
+ LEFT JOIN events USING (event_id, room_id)
+ WHERE membership = 'join'
+ AND user_id = ?
+ AND room_id = ?
+ """,
+ (user_id, room_id),
+ )
+ row = txn.fetchone()
+
+ if row is None:
+ return 0
+
+ stream_ordering = row[0]
+
+ # Count the messages that qualify as unread after the stream ordering we've just
+ # retrieved.
+ sql = """
+ SELECT COUNT(*) FROM events
+ WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
+ """
+
+ txn.execute(sql, (user_id, room_id, stream_ordering))
+ row = txn.fetchone()
+
+ return row[0] if row else 0
+
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/purge_events.py b/synapse/storage/data_stores/main/purge_events.py
index 6546569139..b53fe35c33 100644
--- a/synapse/storage/data_stores/main/purge_events.py
+++ b/synapse/storage/data_stores/main/purge_events.py
@@ -62,6 +62,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# event_json
# event_push_actions
# event_reference_hashes
+ # event_relations
# event_search
# event_to_state_groups
# events
@@ -209,6 +210,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_edges",
"event_forward_extremities",
"event_reference_hashes",
+ "event_relations",
"event_search",
"rejections",
):
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index d2e1e36e7f..ab48052cdc 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -23,8 +23,6 @@ from typing import Any, Dict, List, Optional, Tuple
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
@@ -32,7 +30,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.search import SearchStore
from synapse.storage.database import Database, LoggingTransaction
from synapse.types import ThirdPartyInstanceID
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -192,8 +190,7 @@ class RoomWorkerStore(SQLBaseStore):
return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
- @defer.inlineCallbacks
- def get_largest_public_rooms(
+ async def get_largest_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
@@ -330,10 +327,10 @@ class RoomWorkerStore(SQLBaseStore):
return results
- ret_val = yield self.db.runInteraction(
+ ret_val = await self.db.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
- defer.returnValue(ret_val)
+ return ret_val
@cached(max_entries=10000)
def is_room_blocked(self, room_id):
@@ -509,8 +506,8 @@ class RoomWorkerStore(SQLBaseStore):
"get_rooms_paginate", _get_rooms_paginate_txn,
)
- @cachedInlineCallbacks(max_entries=10000)
- def get_ratelimit_for_user(self, user_id):
+ @cached(max_entries=10000)
+ async def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
user
@@ -522,7 +519,7 @@ class RoomWorkerStore(SQLBaseStore):
of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely.
"""
- row = yield self.db.simple_select_one(
+ row = await self.db.simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
@@ -538,8 +535,8 @@ class RoomWorkerStore(SQLBaseStore):
else:
return None
- @cachedInlineCallbacks()
- def get_retention_policy_for_room(self, room_id):
+ @cached()
+ async def get_retention_policy_for_room(self, room_id):
"""Get the retention policy for a given room.
If no retention policy has been found for this room, returns a policy defined
@@ -566,19 +563,17 @@ class RoomWorkerStore(SQLBaseStore):
return self.db.cursor_to_dict(txn)
- ret = yield self.db.runInteraction(
+ ret = await self.db.runInteraction(
"get_retention_policy_for_room", get_retention_policy_for_room_txn,
)
# If we don't know this room ID, ret will be None, in this case return the default
# policy.
if not ret:
- defer.returnValue(
- {
- "min_lifetime": self.config.retention_default_min_lifetime,
- "max_lifetime": self.config.retention_default_max_lifetime,
- }
- )
+ return {
+ "min_lifetime": self.config.retention_default_min_lifetime,
+ "max_lifetime": self.config.retention_default_max_lifetime,
+ }
row = ret[0]
@@ -592,7 +587,7 @@ class RoomWorkerStore(SQLBaseStore):
if row["max_lifetime"] is None:
row["max_lifetime"] = self.config.retention_default_max_lifetime
- defer.returnValue(row)
+ return row
def get_media_mxcs_in_room(self, room_id):
"""Retrieves all the local and remote media MXC URIs in a given room
@@ -881,8 +876,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self._background_add_rooms_room_version_column,
)
- @defer.inlineCallbacks
- def _background_insert_retention(self, progress, batch_size):
+ async def _background_insert_retention(self, progress, batch_size):
"""Retrieves a list of all rooms within a range and inserts an entry for each of
them into the room_retention table.
NULLs the property's columns if missing from the retention event in the room's
@@ -940,14 +934,14 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
else:
return False
- end = yield self.db.runInteraction(
+ end = await self.db.runInteraction(
"insert_room_retention", _background_insert_retention_txn,
)
if end:
- yield self.db.updates._end_background_update("insert_room_retention")
+ await self.db.updates._end_background_update("insert_room_retention")
- defer.returnValue(batch_size)
+ return batch_size
async def _background_add_rooms_room_version_column(
self, progress: dict, batch_size: int
@@ -1096,8 +1090,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
lock=False,
)
- @defer.inlineCallbacks
- def store_room(
+ async def store_room(
self,
room_id: str,
room_creator_user_id: str,
@@ -1140,7 +1133,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
- yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
+ await self.db.runInteraction("store_room_txn", store_room_txn, next_id)
except Exception as e:
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
@@ -1165,8 +1158,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
lock=False,
)
- @defer.inlineCallbacks
- def set_room_is_public(self, room_id, is_public):
+ async def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
self.db.simple_update_one_txn(
txn,
@@ -1206,13 +1198,12 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
- yield self.db.runInteraction(
+ await self.db.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
self.hs.get_notifier().on_new_replication_data()
- @defer.inlineCallbacks
- def set_room_is_public_appservice(
+ async def set_room_is_public_appservice(
self, room_id, appservice_id, network_id, is_public
):
"""Edit the appservice/network specific public room list.
@@ -1287,7 +1278,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
with self._public_room_id_gen.get_next() as next_id:
- yield self.db.runInteraction(
+ await self.db.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
next_id,
@@ -1327,52 +1318,47 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
- @defer.inlineCallbacks
- def block_room(self, room_id, user_id):
+ async def block_room(self, room_id: str, user_id: str) -> None:
"""Marks the room as blocked. Can be called multiple times.
Args:
- room_id (str): Room to block
- user_id (str): Who blocked it
-
- Returns:
- Deferred
+ room_id: Room to block
+ user_id: Who blocked it
"""
- yield self.db.simple_upsert(
+ await self.db.simple_upsert(
table="blocked_rooms",
keyvalues={"room_id": room_id},
values={},
insertion_values={"user_id": user_id},
desc="block_room",
)
- yield self.db.runInteraction(
+ await self.db.runInteraction(
"block_room_invalidation",
self._invalidate_cache_and_stream,
self.is_room_blocked,
(room_id,),
)
- @defer.inlineCallbacks
- def get_rooms_for_retention_period_in_range(
- self, min_ms, max_ms, include_null=False
- ):
+ async def get_rooms_for_retention_period_in_range(
+ self, min_ms: Optional[int], max_ms: Optional[int], include_null: bool = False
+ ) -> Dict[str, dict]:
"""Retrieves all of the rooms within the given retention range.
Optionally includes the rooms which don't have a retention policy.
Args:
- min_ms (int|None): Duration in milliseconds that define the lower limit of
+ min_ms: Duration in milliseconds that define the lower limit of
the range to handle (exclusive). If None, doesn't set a lower limit.
- max_ms (int|None): Duration in milliseconds that define the upper limit of
+ max_ms: Duration in milliseconds that define the upper limit of
the range to handle (inclusive). If None, doesn't set an upper limit.
- include_null (bool): Whether to include rooms which retention policy is NULL
+ include_null: Whether to include rooms which retention policy is NULL
in the returned set.
Returns:
- dict[str, dict]: The rooms within this range, along with their retention
- policy. The key is "room_id", and maps to a dict describing the retention
- policy associated with this room ID. The keys for this nested dict are
- "min_lifetime" (int|None), and "max_lifetime" (int|None).
+ The rooms within this range, along with their retention
+ policy. The key is "room_id", and maps to a dict describing the retention
+ policy associated with this room ID. The keys for this nested dict are
+ "min_lifetime" (int|None), and "max_lifetime" (int|None).
"""
def get_rooms_for_retention_period_in_range_txn(txn):
@@ -1431,9 +1417,9 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
return rooms_dict
- rooms = yield self.db.runInteraction(
+ rooms = await self.db.runInteraction(
"get_rooms_for_retention_period_in_range",
get_rooms_for_retention_period_in_range_txn,
)
- defer.returnValue(rooms)
+ return rooms
diff --git a/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
new file mode 100644
index 0000000000..531b532c73
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/58/12unread_messages.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Store a boolean value in the events table for whether the event should be counted in
+-- the unread_count property of sync responses.
+ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index bb38a04ede..a360699408 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -16,12 +16,12 @@
import collections.abc
import logging
from collections import namedtuple
-
-from twisted.internet import defer
+from typing import Iterable, Optional, Set
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
@@ -108,28 +108,27 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = await self.get_create_event_for_room(room_id)
return create_event.content.get("room_version", "1")
- @defer.inlineCallbacks
- def get_room_predecessor(self, room_id):
+ async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
Args:
- room_id (str)
+ room_id: The room ID.
Returns:
- Deferred[dict|None]: A dictionary containing the structure of the predecessor
- field from the room's create event. The structure is subject to other servers,
- but it is expected to be:
- * room_id (str): The room ID of the predecessor room
- * event_id (str): The ID of the tombstone event in the predecessor room
+ A dictionary containing the structure of the predecessor
+ field from the room's create event. The structure is subject to other servers,
+ but it is expected to be:
+ * room_id (str): The room ID of the predecessor room
+ * event_id (str): The ID of the tombstone event in the predecessor room
- None if a predecessor key is not found, or is not a dictionary.
+ None if a predecessor key is not found, or is not a dictionary.
Raises:
NotFoundError if the given room is unknown
"""
# Retrieve the room's create event
- create_event = yield self.get_create_event_for_room(room_id)
+ create_event = await self.get_create_event_for_room(room_id)
# Retrieve the predecessor key of the create event
predecessor = create_event.content.get("predecessor", None)
@@ -140,20 +139,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return predecessor
- @defer.inlineCallbacks
- def get_create_event_for_room(self, room_id):
+ async def get_create_event_for_room(self, room_id: str) -> EventBase:
"""Get the create state event for a room.
Args:
- room_id (str)
+ room_id: The room ID.
Returns:
- Deferred[EventBase]: The room creation event.
+ The room creation event.
Raises:
NotFoundError if the room is unknown
"""
- state_ids = yield self.get_current_state_ids(room_id)
+ state_ids = await self.get_current_state_ids(room_id)
create_id = state_ids.get((EventTypes.Create, ""))
# If we can't find the create event, assume we've hit a dead end
@@ -161,7 +159,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
raise NotFoundError("Unknown room %s" % (room_id,))
# Retrieve the room's create event and return
- create_event = yield self.get_event(create_id)
+ create_event = await self.get_event(create_id)
return create_event
@cached(max_entries=100000, iterable=True)
@@ -237,18 +235,17 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
- @defer.inlineCallbacks
- def get_canonical_alias_for_room(self, room_id):
+ async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
"""Get canonical alias for room, if any
Args:
- room_id (str)
+ room_id: The room ID
Returns:
- Deferred[str|None]: The canonical alias, if any
+ The canonical alias, if any
"""
- state = yield self.get_filtered_current_state_ids(
+ state = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
)
@@ -256,7 +253,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event_id:
return
- event = yield self.get_event(event_id, allow_none=True)
+ event = await self.get_event(event_id, allow_none=True)
if not event:
return
@@ -292,19 +289,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {row["event_id"]: row["state_group"] for row in rows}
- @defer.inlineCallbacks
- def get_referenced_state_groups(self, state_groups):
+ async def get_referenced_state_groups(
+ self, state_groups: Iterable[int]
+ ) -> Set[int]:
"""Check if the state groups are referenced by events.
Args:
- state_groups (Iterable[int])
+ state_groups
Returns:
- Deferred[set[int]]: The subset of state groups that are
- referenced.
+ The subset of state groups that are referenced.
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db.simple_select_many_batch(
table="event_to_state_groups",
column="state_group",
iterable=state_groups,
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 380c1ec7da..922400a7c3 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -16,8 +16,8 @@
import logging
from itertools import chain
+from typing import Tuple
-from twisted.internet import defer
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
@@ -97,13 +97,12 @@ class StatsStore(StateDeltasStore):
"""
return (ts // self.stats_bucket_size) * self.stats_bucket_size
- @defer.inlineCallbacks
- def _populate_stats_process_users(self, progress, batch_size):
+ async def _populate_stats_process_users(self, progress, batch_size):
"""
This is a background update which regenerates statistics for users.
"""
if not self.stats_enabled:
- yield self.db.updates._end_background_update("populate_stats_process_users")
+ await self.db.updates._end_background_update("populate_stats_process_users")
return 1
last_user_id = progress.get("last_user_id", "")
@@ -118,20 +117,20 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_user_id, batch_size))
return [r for r, in txn]
- users_to_work_on = yield self.db.runInteraction(
+ users_to_work_on = await self.db.runInteraction(
"_populate_stats_process_users", _get_next_batch
)
# No more rooms -- complete the transaction.
if not users_to_work_on:
- yield self.db.updates._end_background_update("populate_stats_process_users")
+ await self.db.updates._end_background_update("populate_stats_process_users")
return 1
for user_id in users_to_work_on:
- yield self._calculate_and_set_initial_state_for_user(user_id)
+ await self._calculate_and_set_initial_state_for_user(user_id)
progress["last_user_id"] = user_id
- yield self.db.runInteraction(
+ await self.db.runInteraction(
"populate_stats_process_users",
self.db.updates._background_update_progress_txn,
"populate_stats_process_users",
@@ -140,13 +139,12 @@ class StatsStore(StateDeltasStore):
return len(users_to_work_on)
- @defer.inlineCallbacks
- def _populate_stats_process_rooms(self, progress, batch_size):
+ async def _populate_stats_process_rooms(self, progress, batch_size):
"""
This is a background update which regenerates statistics for rooms.
"""
if not self.stats_enabled:
- yield self.db.updates._end_background_update("populate_stats_process_rooms")
+ await self.db.updates._end_background_update("populate_stats_process_rooms")
return 1
last_room_id = progress.get("last_room_id", "")
@@ -161,20 +159,20 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, (last_room_id, batch_size))
return [r for r, in txn]
- rooms_to_work_on = yield self.db.runInteraction(
+ rooms_to_work_on = await self.db.runInteraction(
"populate_stats_rooms_get_batch", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
- yield self.db.updates._end_background_update("populate_stats_process_rooms")
+ await self.db.updates._end_background_update("populate_stats_process_rooms")
return 1
for room_id in rooms_to_work_on:
- yield self._calculate_and_set_initial_state_for_room(room_id)
+ await self._calculate_and_set_initial_state_for_room(room_id)
progress["last_room_id"] = room_id
- yield self.db.runInteraction(
+ await self.db.runInteraction(
"_populate_stats_process_rooms",
self.db.updates._background_update_progress_txn,
"populate_stats_process_rooms",
@@ -696,16 +694,16 @@ class StatsStore(StateDeltasStore):
return room_deltas, user_deltas
- @defer.inlineCallbacks
- def _calculate_and_set_initial_state_for_room(self, room_id):
+ async def _calculate_and_set_initial_state_for_room(
+ self, room_id: str
+ ) -> Tuple[dict, dict, int]:
"""Calculate and insert an entry into room_stats_current.
Args:
- room_id (str)
+ room_id: The room ID under calculation.
Returns:
- Deferred[tuple[dict, dict, int]]: A tuple of room state, membership
- counts and stream position.
+ A tuple of room state, membership counts and stream position.
"""
def _fetch_current_state_stats(txn):
@@ -767,11 +765,11 @@ class StatsStore(StateDeltasStore):
current_state_events_count,
users_in_room,
pos,
- ) = yield self.db.runInteraction(
+ ) = await self.db.runInteraction(
"get_initial_state_for_room", _fetch_current_state_stats
)
- state_event_map = yield self.get_events(event_ids, get_prev_content=False)
+ state_event_map = await self.get_events(event_ids, get_prev_content=False)
room_state = {
"join_rules": None,
@@ -806,11 +804,11 @@ class StatsStore(StateDeltasStore):
event.content.get("m.federate", True) is True
)
- yield self.update_room_state(room_id, room_state)
+ await self.update_room_state(room_id, room_state)
local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
- yield self.update_stats_delta(
+ await self.update_stats_delta(
ts=self.clock.time_msec(),
stats_type="room",
stats_id=room_id,
@@ -826,8 +824,7 @@ class StatsStore(StateDeltasStore):
},
)
- @defer.inlineCallbacks
- def _calculate_and_set_initial_state_for_user(self, user_id):
+ async def _calculate_and_set_initial_state_for_user(self, user_id):
def _calculate_and_set_initial_state_for_user_txn(txn):
pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
@@ -842,12 +839,12 @@ class StatsStore(StateDeltasStore):
(count,) = txn.fetchone()
return count, pos
- joined_rooms, pos = yield self.db.runInteraction(
+ joined_rooms, pos = await self.db.runInteraction(
"calculate_and_set_initial_state_for_user",
_calculate_and_set_initial_state_for_user_txn,
)
- yield self.update_stats_delta(
+ await self.update_stats_delta(
ts=self.clock.time_msec(),
stats_type="user",
stats_id=user_id,
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
index 128c09a2cf..7dada7f75f 100644
--- a/synapse/storage/data_stores/state/store.py
+++ b/synapse/storage/data_stores/state/store.py
@@ -139,10 +139,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"get_state_group_delta", _get_state_group_delta_txn
)
- @defer.inlineCallbacks
- def _get_state_groups_from_groups(
+ async def _get_state_groups_from_groups(
self, groups: List[int], state_filter: StateFilter
- ):
+ ) -> Dict[int, StateMap[str]]:
"""Returns the state groups for a given set of groups from the
database, filtering on types of state events.
@@ -151,13 +150,13 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+ Dict of state group to state map.
"""
results = {}
chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
- res = yield self.db.runInteraction(
+ res = await self.db.runInteraction(
"_get_state_groups_from_groups",
self._get_state_groups_from_groups_txn,
chunk,
@@ -206,10 +205,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return state_filter.filter_state(state_dict_ids), not missing_types
- @defer.inlineCallbacks
- def _get_state_for_groups(
+ async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
- ):
+ ) -> Dict[int, StateMap[str]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@@ -219,7 +217,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+ Dict of state group to state map.
"""
member_filter, non_member_filter = state_filter.get_member_split()
@@ -228,14 +226,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
(
non_member_state,
incomplete_groups_nm,
- ) = yield self._get_state_for_groups_using_cache(
+ ) = self._get_state_for_groups_using_cache(
groups, self._state_group_cache, state_filter=non_member_filter
)
- (
- member_state,
- incomplete_groups_m,
- ) = yield self._get_state_for_groups_using_cache(
+ (member_state, incomplete_groups_m,) = self._get_state_for_groups_using_cache(
groups, self._state_group_members_cache, state_filter=member_filter
)
@@ -256,7 +251,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
# Help the cache hit ratio by expanding the filter a bit
db_state_filter = state_filter.return_expanded()
- group_to_state_dict = yield self._get_state_groups_from_groups(
+ group_to_state_dict = await self._get_state_groups_from_groups(
list(incomplete_groups), state_filter=db_state_filter
)
@@ -576,19 +571,19 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
((sg,) for sg in state_groups_to_delete),
)
- @defer.inlineCallbacks
- def get_previous_state_groups(self, state_groups):
+ async def get_previous_state_groups(
+ self, state_groups: Iterable[int]
+ ) -> Dict[int, int]:
"""Fetch the previous groups of the given state groups.
Args:
- state_groups (Iterable[int])
+ state_groups
Returns:
- Deferred[dict[int, int]]: mapping from state group to previous
- state group.
+ A mapping from state group to previous state group.
"""
- rows = yield self.db.simple_select_many_batch(
+ rows = await self.db.simple_select_many_batch(
table="state_group_edges",
column="prev_state_group",
iterable=state_groups,
|