diff --git a/synapse/__init__.py b/synapse/__init__.py
index 1bed6393bd..fbfd506a43 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -21,6 +21,7 @@ import os
import sys
from synapse.util.rust import check_rust_lib_up_to_date
+from synapse.util.stringutils import strtobool
# Check that we're not running on an unsupported Python version.
if sys.version_info < (3, 7):
@@ -28,25 +29,22 @@ if sys.version_info < (3, 7):
sys.exit(1)
# Allow using the asyncio reactor via env var.
-if bool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", False)):
- try:
- from incremental import Version
+if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")):
+ from incremental import Version
- import twisted
+ import twisted
- # We need a bugfix that is included in Twisted 21.2.0:
- # https://twistedmatrix.com/trac/ticket/9787
- if twisted.version < Version("Twisted", 21, 2, 0):
- print("Using asyncio reactor requires Twisted>=21.2.0")
- sys.exit(1)
+ # We need a bugfix that is included in Twisted 21.2.0:
+ # https://twistedmatrix.com/trac/ticket/9787
+ if twisted.version < Version("Twisted", 21, 2, 0):
+ print("Using asyncio reactor requires Twisted>=21.2.0")
+ sys.exit(1)
- import asyncio
+ import asyncio
- from twisted.internet import asyncioreactor
+ from twisted.internet import asyncioreactor
- asyncioreactor.install(asyncio.get_event_loop())
- except ImportError:
- pass
+ asyncioreactor.install(asyncio.get_event_loop())
# Twisted and canonicaljson will fail to import when this file is executed to
# get the __version__ during a fresh install. That's OK and subsequent calls to
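The switch to strtobool fixes the env-var parsing: previously any non-empty value, including "0", was truthy via bool(), so SYNAPSE_ASYNC_IO_REACTOR=0 still enabled the asyncio reactor. The helper imported above is not shown in this diff; purely as an illustration of the semantics assumed here (modelled on distutils.util.strtobool), a minimal sketch:

    def strtobool(val: str) -> bool:
        # Sketch only: the real helper lives in synapse.util.stringutils and
        # may differ in detail.
        val = val.lower()
        if val in ("y", "yes", "t", "true", "on", "1"):
            return True
        if val in ("n", "no", "f", "false", "off", "0"):
            return False
        raise ValueError(f"invalid truth value {val!r}")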
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index 5fa599e70e..d850e54e17 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -72,6 +72,7 @@ from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
)
+from synapse.storage.databases.main.relations import RelationsWorkerStore
from synapse.storage.databases.main.room import RoomBackgroundUpdateStore
from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore
from synapse.storage.databases.main.search import SearchBackgroundUpdateStore
@@ -206,6 +207,7 @@ class Store(
PusherWorkerStore,
PresenceBackgroundUpdateStore,
ReceiptsBackgroundUpdateStore,
+ RelationsWorkerStore,
):
def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]:
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index c031903b1a..44c5ffc6a5 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -31,6 +31,9 @@ MAX_ALIAS_LENGTH = 255
# the maximum length for a user id is 255 characters
MAX_USERID_LENGTH = 255
+# Constant value used for the pseudo-thread which is the main timeline.
+MAIN_TIMELINE: Final = "main"
+
class Membership:
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index f7f46f8d80..cc31cf8cc7 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -84,6 +84,8 @@ ROOM_EVENT_FILTER_SCHEMA = {
"contains_url": {"type": "boolean"},
"lazy_load_members": {"type": "boolean"},
"include_redundant_members": {"type": "boolean"},
+ "unread_thread_notifications": {"type": "boolean"},
+ "org.matrix.msc3773.unread_thread_notifications": {"type": "boolean"},
# Include or exclude events with the provided labels.
# cf https://github.com/matrix-org/matrix-doc/pull/2326
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
@@ -240,6 +242,9 @@ class FilterCollection:
def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members
+ def unread_thread_notifications(self) -> bool:
+ return self._room_timeline_filter.unread_thread_notifications
+
async def filter_presence(
self, events: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
@@ -304,6 +309,16 @@ class Filter:
self.include_redundant_members = filter_json.get(
"include_redundant_members", False
)
+ self.unread_thread_notifications: bool = filter_json.get(
+ "unread_thread_notifications", False
+ )
+ if (
+ not self.unread_thread_notifications
+ and hs.config.experimental.msc3773_enabled
+ ):
+ self.unread_thread_notifications = filter_json.get(
+ "org.matrix.msc3773.unread_thread_notifications", False
+ )
self.types = filter_json.get("types", None)
self.not_types = filter_json.get("not_types", [])
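For context, a hedged example of a /sync filter opting into the per-thread counts wired up above. The placement under room.timeline follows the ROOM_EVENT_FILTER_SCHEMA this hunk extends; the MSC3773-prefixed field is the unstable fallback read when the experimental flag is enabled:

    # Illustrative only: values and the surrounding filter fields are examples.
    sync_filter = {
        "room": {
            "timeline": {
                "limit": 20,
                "unread_thread_notifications": True,
                # Unstable fallback while MSC3773 remains experimental:
                "org.matrix.msc3773.unread_thread_notifications": True,
            }
        }
    }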
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 5e3825fca6..dc49840f73 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -65,6 +65,7 @@ from synapse.rest.client import (
push_rule,
read_marker,
receipts,
+ relations,
room,
room_batch,
room_keys,
@@ -308,6 +309,7 @@ class GenericWorkerServer(HomeServer):
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
room.register_servlets(self, resource, is_worker=True)
+ relations.register_servlets(self, resource)
room.register_deprecated_servlets(self, resource)
initial_sync.register_servlets(self, resource)
room_batch.register_servlets(self, resource)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 0963fb3bb4..fbac4375b0 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -120,7 +120,11 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
try:
- response = await self.get_json(uri, {"access_token": service.hs_token})
+ response = await self.get_json(
+ uri,
+ {"access_token": service.hs_token},
+ headers={"Authorization": f"Bearer {service.hs_token}"},
+ )
if response is not None: # just an empty json object
return True
except CodeMessageException as e:
@@ -140,7 +144,11 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
try:
- response = await self.get_json(uri, {"access_token": service.hs_token})
+ response = await self.get_json(
+ uri,
+ {"access_token": service.hs_token},
+ headers={"Authorization": f"Bearer {service.hs_token}"},
+ )
if response is not None: # just an empty json object
return True
except CodeMessageException as e:
@@ -181,7 +189,9 @@ class ApplicationServiceApi(SimpleHttpClient):
**fields,
b"access_token": service.hs_token,
}
- response = await self.get_json(uri, args=args)
+ response = await self.get_json(
+ uri, args=args, headers={"Authorization": f"Bearer {service.hs_token}"}
+ )
if not isinstance(response, list):
logger.warning(
"query_3pe to %s returned an invalid response %r", uri, response
@@ -217,7 +227,11 @@ class ApplicationServiceApi(SimpleHttpClient):
urllib.parse.quote(protocol),
)
try:
- info = await self.get_json(uri, {"access_token": service.hs_token})
+ info = await self.get_json(
+ uri,
+ {"access_token": service.hs_token},
+ headers={"Authorization": f"Bearer {service.hs_token}"},
+ )
if not _is_valid_3pe_metadata(info):
logger.warning(
@@ -313,6 +327,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri=uri,
json_body=body,
args={"access_token": service.hs_token},
+ headers={"Authorization": f"Bearer {service.hs_token}"},
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 31834fb27d..f44655516e 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -95,16 +95,8 @@ class ExperimentalConfig(Config):
# MSC2815 (allow room moderators to view redacted event content)
self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False)
- # MSC3786 (Add a default push rule to ignore m.room.server_acl events)
- self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False)
-
- # MSC3771: Thread read receipts
- self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False)
- # MSC3772: A push rule for mutual relations.
- self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)
-
- # MSC3715: dir param on /relations.
- self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)
+ # MSC3773: Thread notifications
+ self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
# MSC3848: Introduce errcodes for specific event sending failures
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
deleted file mode 100644
index baa051fdd4..0000000000
--- a/synapse/config/groups.py
+++ /dev/null
@@ -1,27 +0,0 @@
-# Copyright 2017 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any
-
-from synapse.types import JsonDict
-
-from ._base import Config
-
-
-class GroupsConfig(Config):
- section = "groups"
-
- def read_config(self, config: JsonDict, **kwargs: Any) -> None:
- self.enable_group_creation = config.get("enable_group_creation", False)
- self.group_creation_prefix = config.get("group_creation_prefix", "")
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 6c1f78f8df..b62b3b9205 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -326,6 +326,8 @@ def setup_logging(
logBeginner: The Twisted logBeginner to use.
"""
+ from twisted.internet import reactor
+
log_config_path = (
config.worker.worker_log_config
if use_worker_options
@@ -348,3 +350,4 @@ def setup_logging(
)
logging.info("Server hostname: %s", config.server.server_name)
logging.info("Instance name: %s", hs.get_instance_name())
+ logging.info("Twisted reactor: %s", type(reactor).__name__)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 4dca711cd2..b220ab43fc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1294,7 +1294,7 @@ class FederationClient(FederationBase):
return resp[1]
async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict:
- """Attempts to send a knock event to given a list of servers. Iterates
+ """Attempts to send a knock event to a given list of servers. Iterates
through the list until one attempt succeeds.
Doing so will cause the remote server to add the event to the graph,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 907940e19e..28097664b4 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -824,7 +824,14 @@ class FederationServer(FederationBase):
context, self._room_prejoin_state_types
)
)
- return {"knock_state_events": stripped_room_state}
+ return {
+ "knock_room_state": stripped_room_state,
+ # Since v1.37, Synapse incorrectly used "knock_state_events" for this field.
+ # Thus, we also populate a 'knock_state_events' with the same content to
+ # support old instances.
+ # See https://github.com/matrix-org/synapse/issues/14088.
+ "knock_state_events": stripped_room_state,
+ }
async def _on_send_membership_event(
self, origin: str, content: JsonDict, membership_type: str, room_id: str
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index a6cb3ba58f..774ecd81b6 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -353,21 +353,25 @@ class FederationSender(AbstractFederationSender):
last_token = await self.store.get_federation_out_pos("events")
(
next_token,
- events,
event_to_received_ts,
- ) = await self.store.get_all_new_events_stream(
+ ) = await self.store.get_all_new_event_ids_stream(
last_token, self._last_poked_id, limit=100
)
+ event_ids = event_to_received_ts.keys()
+ event_entries = await self.store.get_unredacted_events_from_cache_or_db(
+ event_ids
+ )
+
logger.debug(
"Handling %i -> %i: %i events to send (current id %i)",
last_token,
next_token,
- len(events),
+ len(event_entries),
self._last_poked_id,
)
- if not events and next_token >= self._last_poked_id:
+ if not event_entries and next_token >= self._last_poked_id:
logger.debug("All events processed")
break
@@ -508,8 +512,14 @@ class FederationSender(AbstractFederationSender):
await handle_event(event)
events_by_room: Dict[str, List[EventBase]] = {}
- for event in events:
- events_by_room.setdefault(event.room_id, []).append(event)
+
+ for event_id in event_ids:
+ # `event_entries` is unsorted, so we have to iterate over `event_ids`
+ # to ensure the events are in the right order
+ event_cache = event_entries.get(event_id)
+ if event_cache:
+ event = event_cache.event
+ events_by_room.setdefault(event.room_id, []).append(event)
await make_deferred_yieldable(
defer.gatherResults(
@@ -524,9 +534,10 @@ class FederationSender(AbstractFederationSender):
logger.debug("Successfully handled up to %i", next_token)
await self.store.update_federation_out_pos("events", next_token)
- if events:
+ if event_entries:
now = self.clock.time_msec()
- ts = event_to_received_ts[events[-1].event_id]
+ last_id = next(reversed(event_ids))
+ ts = event_to_received_ts[last_id]
assert ts is not None
synapse.metrics.event_processing_lag.labels(
@@ -536,7 +547,7 @@ class FederationSender(AbstractFederationSender):
"federation_sender"
).set(ts)
- events_processed_counter.inc(len(events))
+ events_processed_counter.inc(len(event_entries))
event_processing_loop_room_count.labels("federation_sender").inc(
len(events_by_room)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 32074b8ca6..cd39d4d111 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -45,6 +45,7 @@ from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser
from synapse.http.types import QueryParams
from synapse.types import JsonDict
+from synapse.util import ExceptionBundle
logger = logging.getLogger(__name__)
@@ -926,8 +927,7 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
return len(data)
def finish(self) -> SendJoinResponse:
- for c in self._coros:
- c.close()
+ _close_coros(self._coros)
if self._response.event_dict:
self._response.event = make_event_from_dict(
@@ -970,6 +970,27 @@ class _StateParser(ByteParser[StateRequestResponse]):
return len(data)
def finish(self) -> StateRequestResponse:
- for c in self._coros:
- c.close()
+ _close_coros(self._coros)
return self._response
+
+
+def _close_coros(coros: Iterable[Generator[None, bytes, None]]) -> None:
+ """Close each of the given coroutines.
+
+ Always calls .close() on each coroutine, even if doing so raises an exception.
+ Any exceptions raised are aggregated into an ExceptionBundle.
+
+ :raises ExceptionBundle: if at least one coroutine fails to close.
+ """
+ exceptions = []
+ for c in coros:
+ try:
+ c.close()
+ except Exception as e:
+ exceptions.append(e)
+
+ if exceptions:
+ # raise from the first exception so that the traceback has slightly more context
+ raise ExceptionBundle(
+ f"There were {len(exceptions)} errors closing coroutines", exceptions
+ ) from exceptions[0]
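ExceptionBundle is imported from synapse.util above but its definition is outside this diff; as a rough sketch of the behaviour _close_coros relies on, assuming it simply wraps several exceptions into one (somewhat like Python 3.11's ExceptionGroup):

    from typing import Sequence

    class ExceptionBundle(Exception):
        # Assumed shape only; the real class lives in synapse.util.
        def __init__(self, message: str, exceptions: Sequence[Exception]) -> None:
            parts = [message]
            parts.extend(str(e) for e in exceptions)
            super().__init__("\n  ".join(parts))
            self.exceptions = exceptions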
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index 6bb4659c4c..6f11138b57 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -489,7 +489,7 @@ class FederationV2InviteServlet(BaseFederationServerServlet):
room_version = content["room_version"]
event = content["event"]
- invite_room_state = content["invite_room_state"]
+ invite_room_state = content.get("invite_room_state", [])
# Synapse expects invite_room_state to be in unsigned, as it is in v1
# API
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 0478448b47..fc21d58001 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -225,7 +225,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
self,
user: UserID,
from_key: int,
- limit: Optional[int],
+ limit: int,
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 203b62e015..66f5b8d108 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -109,10 +109,13 @@ class ApplicationServicesHandler:
last_token = await self.store.get_appservice_last_pos()
(
upper_bound,
- events,
event_to_received_ts,
- ) = await self.store.get_all_new_events_stream(
- last_token, self.current_max, limit=100, get_prev_content=True
+ ) = await self.store.get_all_new_event_ids_stream(
+ last_token, self.current_max, limit=100
+ )
+
+ events = await self.store.get_events_as_list(
+ event_to_received_ts.keys(), get_prev_content=True
)
events_by_room: Dict[str, List[EventBase]] = {}
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 986ffed3d5..44e70c6c3c 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -781,15 +781,27 @@ class FederationHandler:
# Send the signed event back to the room, and potentially receive some
# further information about the room in the form of partial state events
- stripped_room_state = await self.federation_client.send_knock(
- target_hosts, event
- )
+ knock_response = await self.federation_client.send_knock(target_hosts, event)
# Store any stripped room state events in the "unsigned" key of the event.
# This is a bit of a hack and is cribbing off of invites. Basically we
# store the room state here and retrieve it again when this event appears
# in the invitee's sync stream. It is stripped out for all other local users.
- event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
+ stripped_room_state = (
+ knock_response.get("knock_room_state")
+ # Since v1.37, Synapse incorrectly used "knock_state_events" for this field.
+ # Thus, we also check for a 'knock_state_events' to support old instances.
+ # See https://github.com/matrix-org/synapse/issues/14088.
+ or knock_response.get("knock_state_events")
+ )
+
+ if stripped_room_state is None:
+ raise KeyError(
+ "Missing 'knock_room_state' (or legacy 'knock_state_events') field in "
+ "send_knock response"
+ )
+
+ event.unsigned["knock_room_state"] = stripped_room_state
context = EventContext.for_outlier(self._storage_controllers)
stream_id = await self._federation_event_handler.persist_events_and_notify(
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 778d8869b3..f382961099 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -414,7 +414,9 @@ class FederationEventHandler:
# First, precalculate the joined hosts so that the federation sender doesn't
# need to.
- await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
+ await self._event_creation_handler.cache_joined_hosts_for_events(
+ [(event, context)]
+ )
await self._check_for_soft_fail(event, context=context, origin=origin)
await self._run_push_actions_and_persist_event(event, context)
@@ -2240,8 +2242,8 @@ class FederationEventHandler:
event_pos = PersistedEventPosition(
self._instance_name, event.internal_metadata.stream_ordering
)
- await self._notifier.on_new_room_event(
- event, event_pos, max_stream_token, extra_users=extra_users
+ await self._notifier.on_new_room_events(
+ [(event, event_pos)], max_stream_token, extra_users=extra_users
)
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 860c82c110..9c335e6863 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -57,13 +57,7 @@ class InitialSyncHandler:
self.validator = EventValidator()
self.snapshot_cache: ResponseCache[
Tuple[
- str,
- Optional[StreamToken],
- Optional[StreamToken],
- str,
- Optional[int],
- bool,
- bool,
+ str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
]
] = ResponseCache(hs.get_clock(), "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
@@ -154,11 +148,6 @@ class InitialSyncHandler:
public_room_ids = await self.store.get_public_room_ids()
- if pagin_config.limit is not None:
- limit = pagin_config.limit
- else:
- limit = 10
-
serializer_options = SerializeEventConfig(as_client_event=as_client_event)
async def handle_room(event: RoomsForUser) -> None:
@@ -210,7 +199,7 @@ class InitialSyncHandler:
run_in_background(
self.store.get_recent_events_for_room,
event.room_id,
- limit=limit,
+ limit=pagin_config.limit,
end_token=room_end_token,
),
deferred_room_state,
@@ -360,15 +349,11 @@ class InitialSyncHandler:
member_event_id
)
- limit = pagin_config.limit if pagin_config else None
- if limit is None:
- limit = 10
-
leave_position = await self.store.get_position_for_event(member_event_id)
stream_token = leave_position.to_room_stream_token()
messages, token = await self.store.get_recent_events_for_room(
- room_id, limit=limit, end_token=stream_token
+ room_id, limit=pagin_config.limit, end_token=stream_token
)
messages = await filter_events_for_client(
@@ -420,10 +405,6 @@ class InitialSyncHandler:
now_token = self.hs.get_event_sources().get_current_token()
- limit = pagin_config.limit if pagin_config else None
- if limit is None:
- limit = 10
-
room_members = [
m
for m in current_state.values()
@@ -467,7 +448,7 @@ class InitialSyncHandler:
run_in_background(
self.store.get_recent_events_for_room,
room_id,
- limit=limit,
+ limit=pagin_config.limit,
end_token=now_token.room_key,
),
),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 00e7645ba5..4e55ebba0b 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1390,7 +1390,7 @@ class EventCreationHandler:
extra_users=extra_users,
),
run_in_background(
- self.cache_joined_hosts_for_event, event, context
+ self.cache_joined_hosts_for_events, events_and_context
).addErrback(
log_failure, "cache_joined_hosts_for_event failed"
),
@@ -1491,62 +1491,65 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
- async def cache_joined_hosts_for_event(
- self, event: EventBase, context: EventContext
+ async def cache_joined_hosts_for_events(
+ self, events_and_context: List[Tuple[EventBase, EventContext]]
) -> None:
- """Precalculate the joined hosts at the event, when using Redis, so that
+ """Precalculate the joined hosts at each of the given events, when using Redis, so that
external federation senders don't have to recalculate it themselves.
"""
- if not self._external_cache.is_enabled():
- return
+ for event, _ in events_and_context:
+ if not self._external_cache.is_enabled():
+ return
- # If external cache is enabled we should always have this.
- assert self._external_cache_joined_hosts_updates is not None
+ # If external cache is enabled we should always have this.
+ assert self._external_cache_joined_hosts_updates is not None
- # We actually store two mappings, event ID -> prev state group,
- # state group -> joined hosts, which is much more space efficient
- # than event ID -> joined hosts.
- #
- # Note: We have to cache event ID -> prev state group, as we don't
- # store that in the DB.
- #
- # Note: We set the state group -> joined hosts cache if it hasn't been
- # set for a while, so that the expiry time is reset.
-
- state_entry = await self.state.resolve_state_groups_for_events(
- event.room_id, event_ids=event.prev_event_ids()
- )
+ # We actually store two mappings, event ID -> prev state group,
+ # state group -> joined hosts, which is much more space efficient
+ # than event ID -> joined hosts.
+ #
+ # Note: We have to cache event ID -> prev state group, as we don't
+ # store that in the DB.
+ #
+ # Note: We set the state group -> joined hosts cache if it hasn't been
+ # set for a while, so that the expiry time is reset.
- if state_entry.state_group:
- await self._external_cache.set(
- "event_to_prev_state_group",
- event.event_id,
- state_entry.state_group,
- expiry_ms=60 * 60 * 1000,
+ state_entry = await self.state.resolve_state_groups_for_events(
+ event.room_id, event_ids=event.prev_event_ids()
)
- if state_entry.state_group in self._external_cache_joined_hosts_updates:
- return
+ if state_entry.state_group:
+ await self._external_cache.set(
+ "event_to_prev_state_group",
+ event.event_id,
+ state_entry.state_group,
+ expiry_ms=60 * 60 * 1000,
+ )
- state = await state_entry.get_state(
- self._storage_controllers.state, StateFilter.all()
- )
- with opentracing.start_active_span("get_joined_hosts"):
- joined_hosts = await self.store.get_joined_hosts(
- event.room_id, state, state_entry
+ if state_entry.state_group in self._external_cache_joined_hosts_updates:
+                continue
+
+ state = await state_entry.get_state(
+ self._storage_controllers.state, StateFilter.all()
)
+ with opentracing.start_active_span("get_joined_hosts"):
+ joined_hosts = await self.store.get_joined_hosts(
+ event.room_id, state, state_entry
+ )
- # Note that the expiry times must be larger than the expiry time in
- # _external_cache_joined_hosts_updates.
- await self._external_cache.set(
- "get_joined_hosts",
- str(state_entry.state_group),
- list(joined_hosts),
- expiry_ms=60 * 60 * 1000,
- )
+ # Note that the expiry times must be larger than the expiry time in
+ # _external_cache_joined_hosts_updates.
+ await self._external_cache.set(
+ "get_joined_hosts",
+ str(state_entry.state_group),
+ list(joined_hosts),
+ expiry_ms=60 * 60 * 1000,
+ )
- self._external_cache_joined_hosts_updates[state_entry.state_group] = None
+ self._external_cache_joined_hosts_updates[
+ state_entry.state_group
+ ] = None
async def _validate_canonical_alias(
self,
@@ -1872,6 +1875,7 @@ class EventCreationHandler:
events_and_context, backfilled=backfilled
)
+ events_and_pos = []
for event in persisted_events:
if self._ephemeral_events_enabled:
# If there's an expiry timestamp on the event, schedule its expiry.
@@ -1880,25 +1884,23 @@ class EventCreationHandler:
stream_ordering = event.internal_metadata.stream_ordering
assert stream_ordering is not None
pos = PersistedEventPosition(self._instance_name, stream_ordering)
-
- async def _notify() -> None:
- try:
- await self.notifier.on_new_room_event(
- event, pos, max_stream_token, extra_users=extra_users
- )
- except Exception:
- logger.exception(
- "Error notifying about new room event %s",
- event.event_id,
- )
-
- run_in_background(_notify)
+ events_and_pos.append((event, pos))
if event.type == EventTypes.Message:
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
+ async def _notify() -> None:
+ try:
+ await self.notifier.on_new_room_events(
+ events_and_pos, max_stream_token, extra_users=extra_users
+ )
+ except Exception:
+ logger.exception("Error notifying about new room events")
+
+ run_in_background(_notify)
+
return persisted_events[-1]
async def _maybe_kick_guest_users(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 1f83bab836..a4ca9cb8b4 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -458,11 +458,6 @@ class PaginationHandler:
# `/messages` should still works with live tokens when manually provided.
assert from_token.room_key.topological is not None
- if pagin_config.limit is None:
- # This shouldn't happen as we've set a default limit before this
- # gets called.
- raise Exception("limit not set")
-
room_token = from_token.room_key
async with self.pagination_lock.read(room_id):
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 4e575ffbaa..2670e561d7 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1596,7 +1596,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
self,
user: UserID,
from_key: Optional[int],
- limit: Optional[int] = None,
+ # Having a default limit doesn't match the EventSource API, but some
+ # callers do not provide it. It is unused in this class.
+ limit: int = 0,
room_ids: Optional[Collection[str]] = None,
is_guest: bool = False,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 4768a34c07..ac01582442 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -63,8 +63,6 @@ class ReceiptsHandler:
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
- self._msc3771_enabled = hs.config.experimental.msc3771_enabled
-
async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
"""Called when we receive an EDU of type m.receipt from a remote HS."""
receipts = []
@@ -96,11 +94,10 @@ class ReceiptsHandler:
# Check if these receipts apply to a thread.
thread_id = None
data = user_values.get("data", {})
- if self._msc3771_enabled and isinstance(data, dict):
- thread_id = data.get("thread_id")
- # If the thread ID is invalid, consider it missing.
- if not isinstance(thread_id, str):
- thread_id = None
+ thread_id = data.get("thread_id")
+ # If the thread ID is invalid, consider it missing.
+ if not isinstance(thread_id, str):
+ thread_id = None
receipts.append(
ReadReceipt(
@@ -260,7 +257,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
self,
user: UserID,
from_key: int,
- limit: Optional[int],
+ limit: int,
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 63bc6a7aa5..0a0c6d938e 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import enum
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
@@ -20,7 +21,8 @@ from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.logging.opentracing import trace
-from synapse.storage.databases.main.relations import _RelatedEvent
+from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
+from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client
@@ -31,6 +33,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class ThreadsListInclude(str, enum.Enum):
+ """Valid values for the 'include' flag of /threads."""
+
+ all = "all"
+ participated = "participated"
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
@@ -72,13 +81,10 @@ class RelationsHandler:
requester: Requester,
event_id: str,
room_id: str,
+ pagin_config: PaginationConfig,
+ include_original_event: bool,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
- limit: int = 5,
- direction: str = "b",
- from_token: Optional[StreamToken] = None,
- to_token: Optional[StreamToken] = None,
- include_original_event: bool = False,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
@@ -88,14 +94,10 @@ class RelationsHandler:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
+ pagin_config: The pagination config rules to apply, if any.
+ include_original_event: Whether to include the parent event.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
- limit: Only fetch the most recent `limit` events.
- direction: Whether to fetch the most recent first (`"b"`) or the
- oldest first (`"f"`).
- from_token: Fetch rows from the given token, or from the start if None.
- to_token: Fetch rows up to the given token, or up to the end if None.
- include_original_event: Whether to include the parent event.
Returns:
The pagination chunk.
@@ -123,10 +125,10 @@ class RelationsHandler:
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
- limit=limit,
- direction=direction,
- from_token=from_token,
- to_token=to_token,
+ limit=pagin_config.limit,
+ direction=pagin_config.direction,
+ from_token=pagin_config.from_token,
+ to_token=pagin_config.to_token,
)
events = await self._main_store.get_events_as_list(
@@ -162,8 +164,10 @@ class RelationsHandler:
if next_token:
return_value["next_batch"] = await next_token.to_string(self._main_store)
- if from_token:
- return_value["prev_batch"] = await from_token.to_string(self._main_store)
+ if pagin_config.from_token:
+ return_value["prev_batch"] = await pagin_config.from_token.to_string(
+ self._main_store
+ )
return return_value
@@ -483,3 +487,79 @@ class RelationsHandler:
results.setdefault(event_id, BundledAggregations()).replace = edit
return results
+
+ async def get_threads(
+ self,
+ requester: Requester,
+ room_id: str,
+ include: ThreadsListInclude,
+ limit: int = 5,
+ from_token: Optional[ThreadsNextBatch] = None,
+ ) -> JsonDict:
+        """Retrieve the threads in a room, ordered by topological ordering of their latest reply.
+
+        Args:
+            requester: The user requesting the threads.
+            room_id: The room to fetch threads from.
+            include: One of "all" or "participated" to indicate which threads should
+                be returned.
+            limit: Only fetch the most recent `limit` threads.
+ from_token: Fetch rows from the given token, or from the start if None.
+
+ Returns:
+ The pagination chunk.
+ """
+
+ user_id = requester.user.to_string()
+
+ # TODO Properly handle a user leaving a room.
+ (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+ room_id, requester, allow_departed_users=True
+ )
+
+        # Note that ignored users are not passed into get_threads
+        # below. Ignored users are handled in filter_events_for_client (and by
+ # not passing them in here we should get a better cache hit rate).
+ thread_roots, next_batch = await self._main_store.get_threads(
+ room_id=room_id, limit=limit, from_token=from_token
+ )
+
+ events = await self._main_store.get_events_as_list(thread_roots)
+
+ if include == ThreadsListInclude.participated:
+ # Pre-seed thread participation with whether the requester sent the event.
+ participated = {event.event_id: event.sender == user_id for event in events}
+ # For events the requester did not send, check the database for whether
+ # the requester sent a threaded reply.
+ participated.update(
+ await self._main_store.get_threads_participated(
+ [eid for eid, p in participated.items() if not p],
+ user_id,
+ )
+ )
+
+ # Limit the returned threads to those the user has participated in.
+ events = [event for event in events if participated[event.event_id]]
+
+ events = await filter_events_for_client(
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
+ )
+
+ aggregations = await self.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
+
+ now = self._clock.time_msec()
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ )
+
+ return_value: JsonDict = {"chunk": serialized_events}
+
+ if next_batch:
+ return_value["next_batch"] = str(next_batch)
+
+ return return_value
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 57ab05ad25..4e1aacb408 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1646,7 +1646,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
self,
user: UserID,
from_key: RoomStreamToken,
- limit: Optional[int],
+ limit: int,
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4abb9b6127..1db5d68021 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -40,7 +40,7 @@ from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
-from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.storage.databases.main.event_push_actions import RoomNotifCounts
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
@@ -128,6 +128,7 @@ class JoinedSyncResult:
ephemeral: List[JsonDict]
account_data: List[JsonDict]
unread_notifications: JsonDict
+ unread_thread_notifications: JsonDict
summary: Optional[JsonDict]
unread_count: int
@@ -1288,7 +1289,7 @@ class SyncHandler:
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
- ) -> NotifCounts:
+ ) -> RoomNotifCounts:
with Measure(self.clock, "unread_notifs_for_room_id"):
return await self.store.get_unread_event_push_actions_by_room_for_user(
@@ -1314,6 +1315,19 @@ class SyncHandler:
At the end, we transfer data from the `sync_result_builder` to a new `SyncResult`
instance to signify that the sync calculation is complete.
"""
+
+ user_id = sync_config.user.to_string()
+ app_service = self.store.get_app_service_by_user_id(user_id)
+ if app_service:
+ # We no longer support AS users using /sync directly.
+ # See https://github.com/matrix-org/matrix-doc/issues/1144
+ raise NotImplementedError()
+
+        # Note: we get the user's room list *before* we get the current token; this
+        # avoids checking back in history if rooms are joined after the token is fetched.
+ token_before_rooms = self.event_sources.get_current_token()
+ mutable_joined_room_ids = set(await self.store.get_rooms_for_user(user_id))
+
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
@@ -1321,6 +1335,57 @@ class SyncHandler:
now_token = self.event_sources.get_current_token()
log_kv({"now_token": now_token})
+        # Since we fetched the user's room list before the token, there's a small window
+ # during which membership events may have been persisted, so we fetch these now
+ # and modify the joined room list for any changes between the get_rooms_for_user
+ # call and the get_current_token call.
+ membership_change_events = []
+ if since_token:
+ membership_change_events = await self.store.get_membership_changes_for_user(
+ user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude
+ )
+
+ mem_last_change_by_room_id: Dict[str, EventBase] = {}
+ for event in membership_change_events:
+ mem_last_change_by_room_id[event.room_id] = event
+
+ # For the latest membership event in each room found, add/remove the room ID
+ # from the joined room list accordingly. In this case we only care if the
+ # latest change is JOIN.
+
+ for room_id, event in mem_last_change_by_room_id.items():
+ assert event.internal_metadata.stream_ordering
+ if (
+ event.internal_metadata.stream_ordering
+ < token_before_rooms.room_key.stream
+ ):
+ continue
+
+ logger.info(
+ "User membership change between getting rooms and current token: %s %s %s",
+ user_id,
+ event.membership,
+ room_id,
+ )
+ # User joined a room - we have to then check the room state to ensure we
+ # respect any bans if there's a race between the join and ban events.
+ if event.membership == Membership.JOIN:
+ user_ids_in_room = await self.store.get_users_in_room(room_id)
+ if user_id in user_ids_in_room:
+ mutable_joined_room_ids.add(room_id)
+ # The user left the room, or left and was re-invited but not joined yet
+ else:
+ mutable_joined_room_ids.discard(room_id)
+
+        # Now that we have our list of joined room IDs, exclude rooms as configured and freeze it.
+ joined_room_ids = frozenset(
+ (
+ room_id
+ for room_id in mutable_joined_room_ids
+ if room_id not in self.rooms_to_exclude
+ )
+ )
+
logger.debug(
"Calculating sync response for %r between %s and %s",
sync_config.user,
@@ -1328,22 +1393,13 @@ class SyncHandler:
now_token,
)
- user_id = sync_config.user.to_string()
- app_service = self.store.get_app_service_by_user_id(user_id)
- if app_service:
- # We no longer support AS users using /sync directly.
- # See https://github.com/matrix-org/matrix-doc/issues/1144
- raise NotImplementedError()
- else:
- joined_room_ids = await self.get_rooms_for_user_at(
- user_id, now_token.room_key
- )
sync_result_builder = SyncResultBuilder(
sync_config,
full_state,
since_token=since_token,
now_token=now_token,
joined_room_ids=joined_room_ids,
+ membership_change_events=membership_change_events,
)
logger.debug("Fetching account data")
@@ -1824,19 +1880,12 @@ class SyncHandler:
Does not modify the `sync_result_builder`.
"""
- user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
- now_token = sync_result_builder.now_token
+ membership_change_events = sync_result_builder.membership_change_events
assert since_token
- # Get a list of membership change events that have happened to the user
- # requesting the sync.
- membership_changes = await self.store.get_membership_changes_for_user(
- user_id, since_token.room_key, now_token.room_key
- )
-
- if membership_changes:
+ if membership_change_events:
return True
stream_id = since_token.room_key.stream
@@ -1875,16 +1924,10 @@ class SyncHandler:
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
+ membership_change_events = sync_result_builder.membership_change_events
assert since_token
- # TODO: we've already called this function and ran this query in
- # _have_rooms_changed. We could keep the results in memory to avoid a
- # second query, at the cost of more complicated source code.
- membership_change_events = await self.store.get_membership_changes_for_user(
- user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude
- )
-
mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
for event in membership_change_events:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
@@ -2353,6 +2396,7 @@ class SyncHandler:
ephemeral=ephemeral,
account_data=account_data_events,
unread_notifications=unread_notifications,
+ unread_thread_notifications={},
summary=summary,
unread_count=0,
)
@@ -2360,10 +2404,33 @@ class SyncHandler:
if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
- unread_notifications["notification_count"] = notifs.notify_count
- unread_notifications["highlight_count"] = notifs.highlight_count
-
- room_sync.unread_count = notifs.unread_count
+ # Notifications for the main timeline.
+ notify_count = notifs.main_timeline.notify_count
+ highlight_count = notifs.main_timeline.highlight_count
+ unread_count = notifs.main_timeline.unread_count
+
+ # Check the sync configuration.
+ if sync_config.filter_collection.unread_thread_notifications():
+ # And add info for each thread.
+ room_sync.unread_thread_notifications = {
+ thread_id: {
+ "notification_count": thread_notifs.notify_count,
+ "highlight_count": thread_notifs.highlight_count,
+ }
+ for thread_id, thread_notifs in notifs.threads.items()
+ if thread_id is not None
+ }
+
+ else:
+ # Combine the unread counts for all threads and main timeline.
+ for thread_notifs in notifs.threads.values():
+ notify_count += thread_notifs.notify_count
+ highlight_count += thread_notifs.highlight_count
+ unread_count += thread_notifs.unread_count
+
+ unread_notifications["notification_count"] = notify_count
+ unread_notifications["highlight_count"] = highlight_count
+ room_sync.unread_count = unread_count
sync_result_builder.joined.append(room_sync)
@@ -2385,60 +2452,6 @@ class SyncHandler:
else:
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
- async def get_rooms_for_user_at(
- self,
- user_id: str,
- room_key: RoomStreamToken,
- ) -> FrozenSet[str]:
- """Get set of joined rooms for a user at the given stream ordering.
-
- The stream ordering *must* be recent, otherwise this may throw an
- exception if older than a month. (This function is called with the
- current token, which should be perfectly fine).
-
- Args:
- user_id
- stream_ordering
-
- ReturnValue:
- Set of room_ids the user is in at given stream_ordering.
- """
- joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id)
-
- joined_room_ids = set()
-
- # We need to check that the stream ordering of the join for each room
- # is before the stream_ordering asked for. This might not be the case
- # if the user joins a room between us getting the current token and
- # calling `get_rooms_for_user_with_stream_ordering`.
- # If the membership's stream ordering is after the given stream
- # ordering, we need to go and work out if the user was in the room
- # before.
- # We also need to check whether the room should be excluded from sync
- # responses as per the homeserver config.
- for joined_room in joined_rooms:
- if joined_room.room_id in self.rooms_to_exclude:
- continue
-
- if not joined_room.event_pos.persisted_after(room_key):
- joined_room_ids.add(joined_room.room_id)
- continue
-
- logger.info("User joined room after current token: %s", joined_room.room_id)
-
- extrems = (
- await self.store.get_forward_extremities_for_room_at_stream_ordering(
- joined_room.room_id, joined_room.event_pos.stream
- )
- )
- user_ids_in_room = await self.state.get_current_user_ids_in_room(
- joined_room.room_id, extrems
- )
- if user_id in user_ids_in_room:
- joined_room_ids.add(joined_room.room_id)
-
- return frozenset(joined_room_ids)
-
def _action_has_highlight(actions: List[JsonDict]) -> bool:
for action in actions:
@@ -2535,6 +2548,7 @@ class SyncResultBuilder:
since_token: Optional[StreamToken]
now_token: StreamToken
joined_room_ids: FrozenSet[str]
+ membership_change_events: List[EventBase]
presence: List[UserPresenceState] = attr.Factory(list)
account_data: List[JsonDict] = attr.Factory(list)
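To make the new counts concrete, a sketch of what a joined-room entry could look like in a sync response when the client's filter requests thread notifications. The two count dictionaries mirror the code above; the values, the placeholder event ID and the surrounding response structure are illustrative only:

    # Hypothetical joined-room payload (MSC3773); counts are made up.
    joined_room_payload = {
        "unread_notifications": {
            # Main timeline only, since per-thread counts were requested.
            "notification_count": 2,
            "highlight_count": 1,
        },
        "unread_thread_notifications": {
            # Keyed by thread root event ID ($thread_root_event_id is a placeholder).
            "$thread_root_event_id": {
                "notification_count": 3,
                "highlight_count": 0,
            },
        },
    }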
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index f953691669..a0ea719430 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -513,7 +513,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
self,
user: UserID,
from_key: int,
- limit: Optional[int],
+ limit: int,
room_ids: Iterable[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 80acbdcf3c..dead02cd5c 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -35,6 +35,7 @@ from typing_extensions import Literal
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
+from synapse.http import redact_uri
from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID
from synapse.util import json_decoder
@@ -664,7 +665,13 @@ def parse_json_value_from_request(
try:
content = json_decoder.decode(content_bytes.decode("utf-8"))
except Exception as e:
- logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
+ logger.warning(
+            "Unable to parse JSON from %s %s request body: %s (%s)",
+ request.method.decode("ascii", errors="replace"),
+ redact_uri(request.uri.decode("ascii", errors="replace")),
+ e,
+ content_bytes,
+ )
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON
)
diff --git a/synapse/metrics/_legacy_exposition.py b/synapse/metrics/_legacy_exposition.py
index 563d8cc2c6..1459f9d224 100644
--- a/synapse/metrics/_legacy_exposition.py
+++ b/synapse/metrics/_legacy_exposition.py
@@ -20,7 +20,7 @@ Due to the renaming of metrics in prometheus_client 0.4.0, this customised
vendoring of the code will emit both the old versions that Synapse dashboards
expect, and the newer "best practice" version of the up-to-date official client.
"""
-
+import logging
import math
import threading
from http.server import BaseHTTPRequestHandler, HTTPServer
@@ -34,6 +34,7 @@ from prometheus_client.core import Sample
from twisted.web.resource import Resource
from twisted.web.server import Request
+logger = logging.getLogger(__name__)
CONTENT_TYPE_LATEST = "text/plain; version=0.0.4; charset=utf-8"
@@ -219,11 +220,16 @@ class MetricsHandler(BaseHTTPRequestHandler):
except Exception:
self.send_error(500, "error generating metric output")
raise
- self.send_response(200)
- self.send_header("Content-Type", CONTENT_TYPE_LATEST)
- self.send_header("Content-Length", str(len(output)))
- self.end_headers()
- self.wfile.write(output)
+ try:
+ self.send_response(200)
+ self.send_header("Content-Type", CONTENT_TYPE_LATEST)
+ self.send_header("Content-Length", str(len(output)))
+ self.end_headers()
+ self.wfile.write(output)
+ except BrokenPipeError as e:
+ logger.warning(
+ "BrokenPipeError when serving metrics (%s). Did Prometheus restart?", e
+ )
def log_message(self, format: str, *args: Any) -> None:
"""Log nothing."""
diff --git a/synapse/notifier.py b/synapse/notifier.py
index c42bb8266a..26b97cf766 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -294,35 +294,31 @@ class Notifier:
"""
self._new_join_in_room_callbacks.append(cb)
- async def on_new_room_event(
+ async def on_new_room_events(
self,
- event: EventBase,
- event_pos: PersistedEventPosition,
+ events_and_pos: List[Tuple[EventBase, PersistedEventPosition]],
max_room_stream_token: RoomStreamToken,
extra_users: Optional[Collection[UserID]] = None,
) -> None:
- """Unwraps event and calls `on_new_room_event_args`."""
- await self.on_new_room_event_args(
- event_pos=event_pos,
- room_id=event.room_id,
- event_id=event.event_id,
- event_type=event.type,
- state_key=event.get("state_key"),
- membership=event.content.get("membership"),
- max_room_stream_token=max_room_stream_token,
- extra_users=extra_users or [],
- )
+ """Creates a _PendingRoomEventEntry for each of the listed events and calls
+ notify_new_room_events with the results."""
+ event_entries = []
+ for event, pos in events_and_pos:
+ entry = self.create_pending_room_event_entry(
+ pos,
+ extra_users,
+ event.room_id,
+ event.type,
+ event.get("state_key"),
+ event.content.get("membership"),
+ )
+ event_entries.append((entry, event.event_id))
+ await self.notify_new_room_events(event_entries, max_room_stream_token)
- async def on_new_room_event_args(
+ async def notify_new_room_events(
self,
- room_id: str,
- event_id: str,
- event_type: str,
- state_key: Optional[str],
- membership: Optional[str],
- event_pos: PersistedEventPosition,
+ event_entries: List[Tuple[_PendingRoomEventEntry, str]],
max_room_stream_token: RoomStreamToken,
- extra_users: Optional[Collection[UserID]] = None,
) -> None:
"""Used by handlers to inform the notifier something has happened
in the room, room event wise.
@@ -338,22 +334,33 @@ class Notifier:
until all previous events have been persisted before notifying
the client streams.
"""
- self.pending_new_room_events.append(
- _PendingRoomEventEntry(
- event_pos=event_pos,
- extra_users=extra_users or [],
- room_id=room_id,
- type=event_type,
- state_key=state_key,
- membership=membership,
- )
- )
- self._notify_pending_new_room_events(max_room_stream_token)
+ for event_entry, event_id in event_entries:
+ self.pending_new_room_events.append(event_entry)
+ await self._third_party_rules.on_new_event(event_id)
- await self._third_party_rules.on_new_event(event_id)
+ self._notify_pending_new_room_events(max_room_stream_token)
self.notify_replication()
+ def create_pending_room_event_entry(
+ self,
+ event_pos: PersistedEventPosition,
+ extra_users: Optional[Collection[UserID]],
+ room_id: str,
+ event_type: str,
+ state_key: Optional[str],
+ membership: Optional[str],
+ ) -> _PendingRoomEventEntry:
+ """Creates and returns a _PendingRoomEventEntry"""
+ return _PendingRoomEventEntry(
+ event_pos=event_pos,
+ extra_users=extra_users or [],
+ room_id=room_id,
+ type=event_type,
+ state_key=state_key,
+ membership=membership,
+ )
+
def _notify_pending_new_room_events(
self, max_room_stream_token: RoomStreamToken
) -> None:
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 998354648f..a75386f6a0 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -13,32 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
- Iterable,
List,
Mapping,
Optional,
- Set,
Tuple,
Union,
)
from prometheus_client import Counter
-from synapse.api.constants import EventTypes, Membership, RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, EventTypes, Membership, RelationTypes
from synapse.event_auth import auth_types_for_event, get_user_power_level
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.storage.databases.main.roommember import EventIdMembership
from synapse.storage.state import StateFilter
-from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator
+from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator
from synapse.util.caches import register_cache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_event_for_clients_with_state
@@ -117,9 +114,6 @@ class BulkPushRuleEvaluator:
resizable=False,
)
- # Whether to support MSC3772 is supported.
- self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled
-
async def _get_rules_for_event(
self,
event: EventBase,
@@ -200,51 +194,6 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- async def _get_mutual_relations(
- self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]]
- ) -> Dict[str, Set[Tuple[str, str]]]:
- """
- Fetch event metadata for events which related to the same event as the given event.
-
- If the given event has no relation information, returns an empty dictionary.
-
- Args:
- parent_id: The event ID which is targeted by relations.
- rules: The push rules which will be processed for this event.
-
- Returns:
- A dictionary of relation type to:
- A set of tuples of:
- The sender
- The event type
- """
-
- # If the experimental feature is not enabled, skip fetching relations.
- if not self._relations_match_enabled:
- return {}
-
- # Pre-filter to figure out which relation types are interesting.
- rel_types = set()
- for rule, enabled in rules:
- if not enabled:
- continue
-
- for condition in rule.conditions:
- if condition["kind"] != "org.matrix.msc3772.relation_match":
- continue
-
- # rel_type is required.
- rel_type = condition.get("rel_type")
- if rel_type:
- rel_types.add(rel_type)
-
- # If no valid rules were found, no mutual relations.
- if not rel_types:
- return {}
-
- # If any valid rules were found, fetch the mutual relations.
- return await self.store.get_mutual_event_relations(parent_id, rel_types)
-
@measure_func("action_for_event_by_user")
async def action_for_event_by_user(
self, event: EventBase, context: EventContext
@@ -276,18 +225,18 @@ class BulkPushRuleEvaluator:
sender_power_level,
) = await self._get_power_levels_and_sender_level(event, context)
+ # Find the event's thread ID.
relation = relation_from_event(event)
- # If the event does not have a relation, then cannot have any mutual
- # relations or thread ID.
- relations = {}
- thread_id = "main"
+ # If the event does not have a relation, then it cannot have a thread ID.
+ thread_id = MAIN_TIMELINE
if relation:
- relations = await self._get_mutual_relations(
- relation.parent_id,
- itertools.chain(*(r.rules() for r in rules_by_user.values())),
- )
+ # Recursively attempt to find the thread this event relates to.
if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id
+ else:
+ # Since the event has not yet been persisted we check whether
+ # the parent is part of a thread.
+ thread_id = await self.store.get_thread_id(relation.parent_id)
# It's possible that old room versions have non-integer power levels (floats or
# strings). Workaround this by explicitly converting to int.
@@ -301,8 +250,6 @@ class BulkPushRuleEvaluator:
room_member_count,
sender_power_level,
notification_levels,
- relations,
- self._relations_match_enabled,
)
users = rules_by_user.keys()
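
A minimal, self-contained sketch of the thread-resolution rule the evaluator applies above. "main" mirrors the MAIN_TIMELINE constant; RELATIONS and get_thread_id below are stand-ins for the relations store, not real Synapse APIs.

from typing import Dict, Optional, Tuple

MAIN_TIMELINE = "main"

# event ID -> (rel_type, related-to event ID) for already-persisted events.
RELATIONS: Dict[str, Tuple[str, str]] = {
    "$reply": ("m.thread", "$root"),
}

def get_thread_id(event_id: str) -> str:
    """Stand-in for the store lookup: the thread a persisted event belongs to."""
    relation = RELATIONS.get(event_id)
    if relation and relation[0] == "m.thread":
        return relation[1]
    return MAIN_TIMELINE

def resolve_thread_id(rel_type: Optional[str], parent_id: Optional[str]) -> str:
    """An m.thread relation places the event in the thread rooted at parent_id;
    any other relation inherits the parent's thread; otherwise: main timeline."""
    if rel_type is None or parent_id is None:
        return MAIN_TIMELINE
    if rel_type == "m.thread":
        return parent_id
    return get_thread_id(parent_id)

# A reaction to an event that is itself in a thread is counted in that thread.
assert resolve_thread_id("m.annotation", "$reply") == "$root"
assert resolve_thread_id("m.thread", "$root") == "$root"
assert resolve_thread_id(None, None) == MAIN_TIMELINE
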
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 658bf373b7..edeba27a45 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -39,7 +39,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
await concurrently_execute(get_room_unread_count, joins, 10)
for notifs in room_notifs:
- if notifs.notify_count == 0:
+ # Combine the counts from all the threads.
+ notify_count = notifs.main_timeline.notify_count + sum(
+ n.notify_count for n in notifs.threads.values()
+ )
+
+ if notify_count == 0:
continue
if group_by_room:
@@ -47,7 +52,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge += 1
else:
# increment the badge count by the number of unread messages in the room
- badge += notifs.notify_count
+ badge += notify_count
return badge
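
A condensed sketch of the badge arithmetic above; Counts and RoomCounts are simplified stand-ins for the store's per-room notification-count objects.

from dataclasses import dataclass, field
from typing import Dict

@dataclass
class Counts:
    notify_count: int = 0

@dataclass
class RoomCounts:
    main_timeline: Counts = field(default_factory=Counts)
    threads: Dict[str, Counts] = field(default_factory=dict)

def badge_for_rooms(rooms: Dict[str, RoomCounts], group_by_room: bool) -> int:
    badge = 0
    for counts in rooms.values():
        # Threads contribute to the badge alongside the main timeline.
        total = counts.main_timeline.notify_count + sum(
            t.notify_count for t in counts.threads.values()
        )
        if total == 0:
            continue
        badge += 1 if group_by_room else total
    return badge

rooms = {
    "!a:example.org": RoomCounts(Counts(1), {"$thread": Counts(2)}),
    "!b:example.org": RoomCounts(Counts(0), {}),
}
assert badge_for_rooms(rooms, group_by_room=True) == 1
assert badge_for_rooms(rooms, group_by_room=False) == 3
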
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index b2522f98ca..18252a2958 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -210,15 +210,16 @@ class ReplicationDataHandler:
max_token = self.store.get_room_max_token()
event_pos = PersistedEventPosition(instance_name, token)
- await self.notifier.on_new_room_event_args(
- event_pos=event_pos,
- max_room_stream_token=max_token,
- extra_users=extra_users,
- room_id=row.data.room_id,
- event_id=row.data.event_id,
- event_type=row.data.type,
- state_key=row.data.state_key,
- membership=row.data.membership,
+ event_entry = self.notifier.create_pending_room_event_entry(
+ event_pos,
+ extra_users,
+ row.data.room_id,
+ row.data.type,
+ row.data.state_key,
+ row.data.membership,
+ )
+ await self.notifier.notify_new_room_events(
+ [(event_entry, row.data.event_id)], max_token
)
# If this event is a join, make a note of it so we have an accurate
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index ed6ce78d47..90828c95c4 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -14,18 +14,21 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from pydantic import Extra, StrictStr
from synapse.api import errors
from synapse.api.errors import NotFoundError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
- assert_params_in_dict,
- parse_json_object_from_request,
+ parse_and_validate_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns, interactive_auth_handler
+from synapse.rest.client.models import AuthenticationData
+from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -80,27 +83,29 @@ class DeleteDevicesRestServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
+ class PostBody(RequestBodyModel):
+ auth: Optional[AuthenticationData]
+ devices: List[StrictStr]
+
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
try:
- body = parse_json_object_from_request(request)
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
- # DELETE
+ # TODO: Can/should we remove this fallback now?
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
- body = {}
+ body = self.PostBody.parse_obj({})
else:
raise e
- assert_params_in_dict(body, ["devices"])
-
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
- body,
+ body.dict(exclude_unset=True),
"remove device(s) from your account",
# Users might call this multiple times in a row while cleaning up
# devices, allow a single UI auth session to be re-used.
@@ -108,7 +113,7 @@ class DeleteDevicesRestServlet(RestServlet):
)
await self.device_handler.delete_devices(
- requester.user.to_string(), body["devices"]
+ requester.user.to_string(), body.devices
)
return 200, {}
@@ -147,6 +152,9 @@ class DeviceRestServlet(RestServlet):
return 200, device
+ class DeleteBody(RequestBodyModel):
+ auth: Optional[AuthenticationData]
+
@interactive_auth_handler
async def on_DELETE(
self, request: SynapseRequest, device_id: str
@@ -154,20 +162,21 @@ class DeviceRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
try:
- body = parse_json_object_from_request(request)
+ body = parse_and_validate_json_object_from_request(request, self.DeleteBody)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
+ # TODO: can/should we remove this fallback now?
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
- body = {}
+ body = self.DeleteBody.parse_obj({})
else:
raise
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
- body,
+ body.dict(exclude_unset=True),
"remove a device from your account",
# Users might call this multiple times in a row while cleaning up
# devices, allow a single UI auth session to be re-used.
@@ -179,18 +188,33 @@ class DeviceRestServlet(RestServlet):
)
return 200, {}
+ class PutBody(RequestBodyModel):
+ display_name: Optional[StrictStr]
+
async def on_PUT(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- body = parse_json_object_from_request(request)
+ body = parse_and_validate_json_object_from_request(request, self.PutBody)
await self.device_handler.update_device(
- requester.user.to_string(), device_id, body
+ requester.user.to_string(), device_id, body.dict()
)
return 200, {}
+class DehydratedDeviceDataModel(RequestBodyModel):
+ """JSON blob describing a dehydrated device to be stored.
+
+ Expects other freeform fields. Use .dict() to access them.
+ """
+
+ class Config:
+ extra = Extra.allow
+
+ algorithm: StrictStr
+
+
class DehydratedDeviceServlet(RestServlet):
"""Retrieve or store a dehydrated device.
@@ -246,27 +270,19 @@ class DehydratedDeviceServlet(RestServlet):
else:
raise errors.NotFoundError("No dehydrated device available")
+ class PutBody(RequestBodyModel):
+ device_id: StrictStr
+ device_data: DehydratedDeviceDataModel
+ initial_device_display_name: Optional[StrictStr]
+
async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- submission = parse_json_object_from_request(request)
+ submission = parse_and_validate_json_object_from_request(request, self.PutBody)
requester = await self.auth.get_user_by_req(request)
- if "device_data" not in submission:
- raise errors.SynapseError(
- 400,
- "device_data missing",
- errcode=errors.Codes.MISSING_PARAM,
- )
- elif not isinstance(submission["device_data"], dict):
- raise errors.SynapseError(
- 400,
- "device_data must be an object",
- errcode=errors.Codes.INVALID_PARAM,
- )
-
device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(),
- submission["device_data"],
- submission.get("initial_device_display_name", None),
+ submission.device_data,
+ submission.initial_device_display_name,
)
return 200, {"device_id": device_id}
@@ -300,28 +316,18 @@ class ClaimDehydratedDeviceServlet(RestServlet):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
+ class PostBody(RequestBodyModel):
+ device_id: StrictStr
+
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- submission = parse_json_object_from_request(request)
-
- if "device_id" not in submission:
- raise errors.SynapseError(
- 400,
- "device_id missing",
- errcode=errors.Codes.MISSING_PARAM,
- )
- elif not isinstance(submission["device_id"], str):
- raise errors.SynapseError(
- 400,
- "device_id must be a string",
- errcode=errors.Codes.INVALID_PARAM,
- )
+ submission = parse_and_validate_json_object_from_request(request, self.PostBody)
result = await self.device_handler.rehydrate_device(
requester.user.to_string(),
self.auth.get_access_token_from_request(request),
- submission["device_id"],
+ submission.device_id,
)
return 200, result
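
Roughly how the pydantic-based body parsing behaves, assuming pydantic v1 semantics; the model below is illustrative, not the servlet's exact class.

from typing import List, Optional

from pydantic import BaseModel, StrictStr, ValidationError

class DeleteDevicesBody(BaseModel):
    auth: Optional[dict] = None
    devices: List[StrictStr]

# A well-formed body parses into typed attributes.
body = DeleteDevicesBody.parse_obj({"devices": ["ABCDEFGH"]})
assert body.devices == ["ABCDEFGH"]

# Wrong types are rejected up front rather than failing deep in the handler.
try:
    DeleteDevicesBody.parse_obj({"devices": "ABCDEFGH"})
except ValidationError:
    pass
else:
    raise AssertionError("expected a validation error")
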
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 916f5230f1..782e7d14e8 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -50,7 +50,9 @@ class EventStreamRestServlet(RestServlet):
raise SynapseError(400, "Guest users must specify room_id param")
room_id = parse_string(request, "room_id")
- pagin_config = await PaginationConfig.from_request(self.store, request)
+ pagin_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if b"timeout" in args:
try:
diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py
index cfadcb8e50..9b1bb8b521 100644
--- a/synapse/rest/client/initial_sync.py
+++ b/synapse/rest/client/initial_sync.py
@@ -39,7 +39,9 @@ class InitialSyncRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
args: Dict[bytes, List[bytes]] = request.args # type: ignore
as_client_event = b"raw" not in args
- pagination_config = await PaginationConfig.from_request(self.store, request)
+ pagination_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index f3ff156abe..18a282b22c 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -15,8 +15,8 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from synapse.api.constants import ReceiptTypes
-from synapse.api.errors import SynapseError
+from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
+from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
@@ -43,13 +43,13 @@ class ReceiptRestServlet(RestServlet):
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
+ self._main_store = hs.get_datastores().main
self._known_receipt_types = {
ReceiptTypes.READ,
ReceiptTypes.READ_PRIVATE,
ReceiptTypes.FULLY_READ,
}
- self._msc3771_enabled = hs.config.experimental.msc3771_enabled
async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
@@ -66,13 +66,29 @@ class ReceiptRestServlet(RestServlet):
# Pull the thread ID, if one exists.
thread_id = None
- if self._msc3771_enabled:
- if "thread_id" in body:
- thread_id = body.get("thread_id")
- if not thread_id or not isinstance(thread_id, str):
- raise SynapseError(
- 400, "thread_id field must be a non-empty string"
- )
+ if "thread_id" in body:
+ thread_id = body.get("thread_id")
+ if not thread_id or not isinstance(thread_id, str):
+ raise SynapseError(
+ 400,
+ "thread_id field must be a non-empty string",
+ Codes.INVALID_PARAM,
+ )
+
+ if receipt_type == ReceiptTypes.FULLY_READ:
+ raise SynapseError(
+ 400,
+ f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.",
+ Codes.INVALID_PARAM,
+ )
+
+ # Ensure the event ID roughly correlates to the thread ID.
+ if not await self._is_event_in_thread(event_id, thread_id):
+ raise SynapseError(
+ 400,
+ f"event_id {event_id} is not related to thread {thread_id}",
+ Codes.INVALID_PARAM,
+ )
await self.presence_handler.bump_presence_active_time(requester.user)
@@ -93,6 +109,46 @@ class ReceiptRestServlet(RestServlet):
return 200, {}
+ async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool:
+ """
+ The event must be related to the thread ID (in a vague sense) to ensure
+ clients aren't sending bogus receipts.
+
+ A thread ID is considered valid for a given event E if:
+
+ 1. E has a thread relation which matches the thread ID;
+ 2. E has another event which has a thread relation to E matching the
+ thread ID; or
+ 3. E is recursively related (via any rel_type) to an event which
+ satisfies 1 or 2.
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ It is valid to send a receipt for thread A on A, B, C, D, or E.
+
+ It is valid to send a receipt for the main timeline on A, D, and E.
+
+ Args:
+ event_id: The event ID to check.
+ thread_id: The thread ID the event is potentially part of.
+
+ Returns:
+ True if the event belongs to the given thread, otherwise False.
+ """
+
+ # If the receipt is on the main timeline, it is enough to check whether
+ # the event is directly related to a thread.
+ if thread_id == MAIN_TIMELINE:
+ return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id)
+
+ # Otherwise, check if the event is directly part of a thread, or is the
+ # root message (or related to the root message) of a thread.
+ return thread_id == await self._main_store.get_thread_id_for_receipts(event_id)
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReceiptRestServlet(hs).register(http_server)
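
For reference, a threaded read receipt as a client might now send it (MSC3771); the host, token and event IDs are placeholders, and the request is only constructed here, not sent.

import json
from urllib.request import Request

room_id = "!room:example.org"
event_id = "$event_in_thread"

req = Request(
    f"https://synapse.example.org/_matrix/client/v3/rooms/{room_id}"
    f"/receipt/m.read/{event_id}",
    data=json.dumps({"thread_id": "$thread_root"}).encode("utf-8"),
    headers={"Authorization": "Bearer <token>", "Content-Type": "application/json"},
    method="POST",
)
# The servlet rejects thread_id on m.fully_read receipts, and rejects receipts
# whose event is not (transitively) part of the named thread; "main" targets
# the main timeline.
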
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 7a25de5c85..9dd59196d9 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -13,13 +13,17 @@
# limitations under the License.
import logging
+import re
from typing import TYPE_CHECKING, Optional, Tuple
+from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import JsonDict, StreamToken
+from synapse.storage.databases.main.relations import ThreadsNextBatch
+from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -41,9 +45,8 @@ class RelationPaginationServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
- self.store = hs.get_datastores().main
+ self._store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()
- self._msc3715_enabled = hs.config.experimental.msc3715_enabled
async def on_GET(
self,
@@ -55,49 +58,63 @@ class RelationPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- limit = parse_integer(request, "limit", default=5)
- # Fetch the direction parameter, if provided.
- #
- # TODO Use PaginationConfig.from_request when the unstable parameter is
- # no longer needed.
- direction = parse_string(request, "dir", allowed_values=["f", "b"])
- if direction is None:
- if self._msc3715_enabled:
- direction = parse_string(
- request,
- "org.matrix.msc3715.dir",
- default="b",
- allowed_values=["f", "b"],
- )
- else:
- direction = "b"
- from_token_str = parse_string(request, "from")
- to_token_str = parse_string(request, "to")
-
- # Return the relations
- from_token = None
- if from_token_str:
- from_token = await StreamToken.from_string(self.store, from_token_str)
- to_token = None
- if to_token_str:
- to_token = await StreamToken.from_string(self.store, to_token_str)
+ pagination_config = await PaginationConfig.from_request(
+ self._store, request, default_limit=5, default_dir="b"
+ )
# The unstable version of this API returns an extra field for client
# compatibility, see https://github.com/matrix-org/synapse/issues/12930.
assert request.path is not None
include_original_event = request.path.startswith(b"/_matrix/client/unstable/")
+ # Return the relations
result = await self._relations_handler.get_relations(
requester=requester,
event_id=parent_id,
room_id=room_id,
+ pagin_config=pagination_config,
+ include_original_event=include_original_event,
relation_type=relation_type,
event_type=event_type,
+ )
+
+ return 200, result
+
+
+class ThreadsServlet(RestServlet):
+ PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/threads"),)
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastores().main
+ self._relations_handler = hs.get_relations_handler()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token_str = parse_string(request, "from")
+ include = parse_string(
+ request,
+ "include",
+ default=ThreadsListInclude.all.value,
+ allowed_values=[v.value for v in ThreadsListInclude],
+ )
+
+ # Return the relations
+ from_token = None
+ if from_token_str:
+ from_token = ThreadsNextBatch.from_string(from_token_str)
+
+ result = await self._relations_handler.get_threads(
+ requester=requester,
+ room_id=room_id,
+ include=ThreadsListInclude(include),
limit=limit,
- direction=direction,
from_token=from_token,
- to_token=to_token,
- include_original_event=include_original_event,
)
return 200, result
@@ -105,3 +122,4 @@ class RelationPaginationServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationPaginationServlet(hs).register(http_server)
+ ThreadsServlet(hs).register(http_server)
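
A sketch of calling the new threads listing endpoint registered above. The query parameter names follow the servlet (limit, from, include); the host, token and the pagination token field are assumptions of this example.

from urllib.parse import urlencode
from urllib.request import Request

room_id = "!room:example.org"
query = urlencode({"limit": 5, "include": "participated"})  # or "all"

req = Request(
    f"https://synapse.example.org/_matrix/client/v1/rooms/{room_id}/threads?{query}",
    headers={"Authorization": "Bearer <token>"},
)
# To page through results, pass the opaque token from the previous response
# back as ?from=... (the ThreadsNextBatch token handled above).
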
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index b6dedbed04..01e5079963 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -729,7 +729,9 @@ class RoomInitialSyncRestServlet(RestServlet):
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- pagination_config = await PaginationConfig.from_request(self.store, request)
+ pagination_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index c2989765ce..8a16459105 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -100,6 +100,7 @@ class SyncRestServlet(RestServlet):
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
self._msc2654_enabled = hs.config.experimental.msc2654_enabled
+ self._msc3773_enabled = hs.config.experimental.msc3773_enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
@@ -509,6 +510,12 @@ class SyncRestServlet(RestServlet):
ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
+ if room.unread_thread_notifications:
+ result["unread_thread_notifications"] = room.unread_thread_notifications
+ if self._msc3773_enabled:
+ result[
+ "org.matrix.msc3773.unread_thread_notifications"
+ ] = room.unread_thread_notifications
result["summary"] = room.summary
if self._msc2654_enabled:
result["org.matrix.msc2654.unread_count"] = room.unread_count
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index c95b0d6f19..4e1fd2bbe7 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -75,6 +75,8 @@ class VersionsRestServlet(RestServlet):
"r0.6.1",
"v1.1",
"v1.2",
+ "v1.3",
+ "v1.4",
],
# as per MSC1497:
"unstable_features": {
@@ -103,8 +105,9 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc3030": self.config.experimental.msc3030_enabled,
# Adds support for thread relations, per MSC3440.
"org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above
- # Support for thread read receipts.
- "org.matrix.msc3771": self.config.experimental.msc3771_enabled,
+ # Support for thread read receipts & notification counts.
+ "org.matrix.msc3771": True,
+ "org.matrix.msc3773": self.config.experimental.msc3773_enabled,
# Allows moderators to fetch redacted event content as described in MSC2815
"fi.mau.msc2815": self.config.experimental.msc2815_enabled,
# Adds support for login token requests as per MSC3882
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 2177b46c9e..827afd868d 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -139,65 +139,72 @@ class OEmbedProvider:
try:
# oEmbed responses *must* be UTF-8 according to the spec.
oembed = json_decoder.decode(raw_body.decode("utf-8"))
+ except ValueError:
+ return OEmbedResult({}, None, None)
- # The version is a required string field, but not always provided,
- # or sometimes provided as a float. Be lenient.
- oembed_version = oembed.get("version", "1.0")
- if oembed_version != "1.0" and oembed_version != 1:
- raise RuntimeError(f"Invalid oEmbed version: {oembed_version}")
+ # The version is a required string field, but not always provided,
+ # or sometimes provided as a float. Be lenient.
+ oembed_version = oembed.get("version", "1.0")
+ if oembed_version != "1.0" and oembed_version != 1:
+ return OEmbedResult({}, None, None)
- # Ensure the cache age is None or an int.
- cache_age = oembed.get("cache_age")
- if cache_age:
- cache_age = int(cache_age) * 1000
-
- # The results.
- open_graph_response = {
- "og:url": url,
- }
-
- title = oembed.get("title")
- if title:
- open_graph_response["og:title"] = title
-
- author_name = oembed.get("author_name")
+ # Attempt to parse the cache age, if possible.
+ try:
+ cache_age = int(oembed.get("cache_age")) * 1000
+ except (TypeError, ValueError):
+ # If the cache age cannot be parsed (e.g. wrong type or invalid
+ # string), ignore it.
+ cache_age = None
- # Use the provider name and as the site.
- provider_name = oembed.get("provider_name")
- if provider_name:
- open_graph_response["og:site_name"] = provider_name
+ # The oEmbed response converted to Open Graph.
+ open_graph_response: JsonDict = {"og:url": url}
- # If a thumbnail exists, use it. Note that dimensions will be calculated later.
- if "thumbnail_url" in oembed:
- open_graph_response["og:image"] = oembed["thumbnail_url"]
+ title = oembed.get("title")
+ if title and isinstance(title, str):
+ open_graph_response["og:title"] = title
- # Process each type separately.
- oembed_type = oembed["type"]
- if oembed_type == "rich":
- calc_description_and_urls(open_graph_response, oembed["html"])
-
- elif oembed_type == "photo":
- # If this is a photo, use the full image, not the thumbnail.
- open_graph_response["og:image"] = oembed["url"]
+ author_name = oembed.get("author_name")
+ if not isinstance(author_name, str):
+ author_name = None
- elif oembed_type == "video":
- open_graph_response["og:type"] = "video.other"
+ # Use the provider name as the site name.
+ provider_name = oembed.get("provider_name")
+ if provider_name and isinstance(provider_name, str):
+ open_graph_response["og:site_name"] = provider_name
+
+ # If a thumbnail exists, use it. Note that dimensions will be calculated later.
+ thumbnail_url = oembed.get("thumbnail_url")
+ if thumbnail_url and isinstance(thumbnail_url, str):
+ open_graph_response["og:image"] = thumbnail_url
+
+ # Process each type separately.
+ oembed_type = oembed.get("type")
+ if oembed_type == "rich":
+ html = oembed.get("html")
+ if isinstance(html, str):
+ calc_description_and_urls(open_graph_response, html)
+
+ elif oembed_type == "photo":
+ # If this is a photo, use the full image, not the thumbnail.
+ url = oembed.get("url")
+ if url and isinstance(url, str):
+ open_graph_response["og:image"] = url
+
+ elif oembed_type == "video":
+ open_graph_response["og:type"] = "video.other"
+ html = oembed.get("html")
+ if html and isinstance(html, str):
calc_description_and_urls(open_graph_response, oembed["html"])
- open_graph_response["og:video:width"] = oembed["width"]
- open_graph_response["og:video:height"] = oembed["height"]
-
- elif oembed_type == "link":
- open_graph_response["og:type"] = "website"
+ for size in ("width", "height"):
+ val = oembed.get(size)
+ if val is not None and isinstance(val, int):
+ open_graph_response[f"og:video:{size}"] = val
- else:
- raise RuntimeError(f"Unknown oEmbed type: {oembed_type}")
+ elif oembed_type == "link":
+ open_graph_response["og:type"] = "website"
- except Exception as e:
- # Trap any exception and let the code follow as usual.
- logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
- open_graph_response = {}
- author_name = None
- cache_age = None
+ else:
+ logger.warning("Unknown oEmbed type: %s", oembed_type)
return OEmbedResult(open_graph_response, author_name, cache_age)
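
The lenient parsing style used above, shown in isolation for the cache_age field: coerce when possible, otherwise drop the field rather than discarding the whole oEmbed response.

from typing import Any, Optional

def parse_cache_age(value: Any) -> Optional[int]:
    """Return the cache age in milliseconds, or None if it cannot be parsed."""
    try:
        return int(value) * 1000
    except (TypeError, ValueError):
        return None

assert parse_cache_age("86400") == 86_400_000
assert parse_cache_age(3600) == 3_600_000
assert parse_cache_age(None) is None
assert parse_cache_age("soon") is None
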
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b4469eb964..7bb21f8f81 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -94,7 +94,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
"event_search": "event_search_event_id_idx",
"local_media_repository_thumbnails": "local_media_repository_thumbnails_method_idx",
"remote_media_cache_thumbnails": "remote_media_repository_thumbnails_method_idx",
- "event_push_summary": "event_push_summary_unique_index",
+ "event_push_summary": "event_push_summary_unique_index2",
"receipts_linearized": "receipts_linearized_unique_index",
"receipts_graph": "receipts_graph_unique_index",
}
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 3b8ed1f7ee..ed0be4abe5 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -244,6 +244,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
# redacted.
self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,))
self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,))
+ self._attempt_to_invalidate_cache("get_thread_id", (redacts,))
+ self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,))
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
@@ -259,9 +261,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,))
self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,))
- self._attempt_to_invalidate_cache(
- "get_mutual_event_relations_for_rel_type", (relates_to,)
- )
+ self._attempt_to_invalidate_cache("get_threads", (room_id,))
async def invalidate_cache_and_stream(
self, cache_name: str, keys: Tuple[Any, ...]
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 72cf91eb39..f070e6e88a 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -88,7 +88,7 @@ from typing import (
import attr
-from synapse.api.constants import ReceiptTypes
+from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@@ -119,6 +119,32 @@ DEFAULT_HIGHLIGHT_ACTION: List[Union[dict, str]] = [
]
+@attr.s(slots=True, auto_attribs=True)
+class _RoomReceipt:
+ """
+ The latest stream orderings reached by a user's receipts in a room, tracked
+ separately for the unthreaded receipt and for each threaded receipt.
+ """
+
+ unthreaded_stream_ordering: int = 0
+ # threaded_stream_ordering includes the main pseudo-thread.
+ threaded_stream_ordering: Dict[str, int] = attr.Factory(dict)
+
+ def is_unread(self, thread_id: str, stream_ordering: int) -> bool:
+ """Returns True if the stream ordering is unread according to the receipt information."""
+
+ # Only include push actions with a stream ordering after both the unthreaded
+ # and threaded receipt. Properly handles a user without any receipts present.
+ return (
+ self.unthreaded_stream_ordering < stream_ordering
+ and self.threaded_stream_ordering.get(thread_id, 0) < stream_ordering
+ )
+
+
+# A _RoomReceipt with no receipts in it.
+MISSING_ROOM_RECEIPT = _RoomReceipt()
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class HttpPushAction:
"""
@@ -157,7 +183,7 @@ class UserPushAction(EmailPushAction):
@attr.s(slots=True, auto_attribs=True)
class NotifCounts:
"""
- The per-user, per-room count of notifications. Used by sync and push.
+ The per-user, per-room, per-thread count of notifications. Used by sync and push.
"""
notify_count: int = 0
@@ -165,6 +191,21 @@ class NotifCounts:
highlight_count: int = 0
+@attr.s(slots=True, auto_attribs=True)
+class RoomNotifCounts:
+ """
+ The per-user, per-room count of notifications. Used by sync and push.
+ """
+
+ main_timeline: NotifCounts
+ # Map of thread ID to the notification counts.
+ threads: Dict[str, NotifCounts]
+
+ def __len__(self) -> int:
+ # So that iterable caches can account for the space used by this object.
+ return len(self.threads) + 1
+
+
def _serialize_action(
actions: Collection[Union[Mapping, str]], is_highlight: bool
) -> str:
@@ -384,12 +425,12 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return result
- @cached(tree=True, max_entries=5000)
+ @cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
room_id: str,
user_id: str,
- ) -> NotifCounts:
+ ) -> RoomNotifCounts:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after their latest read receipt.
@@ -402,8 +443,9 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: The user to retrieve the counts for.
Returns
- A NotifCounts object containing the notification count, the highlight count
- and the unread message count.
+ A RoomNotifCounts object containing the notification count, the
+ highlight count and the unread message count for both the main timeline
+ and threads.
"""
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
@@ -417,7 +459,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: LoggingTransaction,
room_id: str,
user_id: str,
- ) -> NotifCounts:
+ ) -> RoomNotifCounts:
# Get the stream ordering of the user's latest receipt in the room.
result = self.get_last_unthreaded_receipt_for_user_txn(
txn,
@@ -451,8 +493,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: LoggingTransaction,
room_id: str,
user_id: str,
- receipt_stream_ordering: int,
- ) -> NotifCounts:
+ unthreaded_receipt_stream_ordering: int,
+ ) -> RoomNotifCounts:
"""Get the number of unread messages for a user/room that have happened
since the given stream ordering.
@@ -460,78 +502,204 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: The database transaction.
room_id: The room ID to get unread counts for.
user_id: The user ID to get unread counts for.
- receipt_stream_ordering: The stream ordering of the user's latest
- receipt in the room. If there are no receipts, the stream ordering
- of the user's join event.
+ unthreaded_receipt_stream_ordering: The stream ordering of the user's latest
+ unthreaded receipt in the room. If there are no unthreaded receipts,
+ the stream ordering of the user's join event.
- Returns
- A NotifCounts object containing the notification count, the highlight count
- and the unread message count.
+ Returns:
+ A RoomNotifCounts object containing the notification count, the
+ highlight count and the unread message count for both the main timeline
+ and threads.
"""
- counts = NotifCounts()
+ main_counts = NotifCounts()
+ thread_counts: Dict[str, NotifCounts] = {}
+
+ def _get_thread(thread_id: str) -> NotifCounts:
+ if thread_id == MAIN_TIMELINE:
+ return main_counts
+ return thread_counts.setdefault(thread_id, NotifCounts())
+
+ receipt_types_clause, receipts_args = make_in_list_sql_clause(
+ self.database_engine,
+ "receipt_type",
+ (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
+ )
# First we pull the counts from the summary table.
#
- # We check that `last_receipt_stream_ordering` matches the stream
- # ordering given. If it doesn't match then a new read receipt has arrived and
- # we haven't yet updated the counts in `event_push_summary` to reflect
- # that; in that case we simply ignore `event_push_summary` counts
- # and do a manual count of all of the rows in the `event_push_actions` table
- # for this user/room.
+ # We check that `last_receipt_stream_ordering` matches the stream ordering of the
+ # latest receipt for the thread (which may be either the unthreaded read receipt
+ # or the threaded read receipt).
+ #
+ # If it doesn't match then a new read receipt has arrived and we haven't yet
+ # updated the counts in `event_push_summary` to reflect that; in that case we
+ # simply ignore `event_push_summary` counts.
+ #
+ # We then do a manual count of all the rows in the `event_push_actions` table
+ # for any user/room/thread for which no valid summary was found.
#
- # If `last_receipt_stream_ordering` is null then that means it's up to
- # date (as the row was written by an older version of Synapse that
+ # If `last_receipt_stream_ordering` is null then that means it's up-to-date
+ # (as the row was written by an older version of Synapse that
# updated `event_push_summary` synchronously when persisting a new read
# receipt).
txn.execute(
- """
- SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
+ f"""
+ SELECT notif_count, COALESCE(unread_count, 0), thread_id
FROM event_push_summary
+ LEFT JOIN (
+ SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering
+ FROM receipts_linearized
+ LEFT JOIN events USING (room_id, event_id)
+ WHERE
+ user_id = ?
+ AND room_id = ?
+ AND stream_ordering > ?
+ AND {receipt_types_clause}
+ GROUP BY thread_id
+ ) AS receipts USING (thread_id)
WHERE room_id = ? AND user_id = ?
AND (
- (last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
- OR last_receipt_stream_ordering = ?
- )
+ (last_receipt_stream_ordering IS NULL AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?))
+ OR last_receipt_stream_ordering = COALESCE(threaded_receipt_stream_ordering, ?)
+ ) AND (notif_count != 0 OR COALESCE(unread_count, 0) != 0)
""",
- (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
+ (
+ user_id,
+ room_id,
+ unthreaded_receipt_stream_ordering,
+ *receipts_args,
+ room_id,
+ user_id,
+ unthreaded_receipt_stream_ordering,
+ unthreaded_receipt_stream_ordering,
+ ),
)
- row = txn.fetchone()
-
- summary_stream_ordering = 0
- if row:
- summary_stream_ordering = row[0]
- counts.notify_count += row[1]
- counts.unread_count += row[2]
+ summarised_threads = set()
+ for notif_count, unread_count, thread_id in txn:
+ summarised_threads.add(thread_id)
+ counts = _get_thread(thread_id)
+ counts.notify_count += notif_count
+ counts.unread_count += unread_count
# Next we need to count highlights, which aren't summarised
- sql = """
- SELECT COUNT(*) FROM event_push_actions
+ sql = f"""
+ SELECT COUNT(*), thread_id FROM event_push_actions
+ LEFT JOIN (
+ SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering
+ FROM receipts_linearized
+ LEFT JOIN events USING (room_id, event_id)
+ WHERE
+ user_id = ?
+ AND room_id = ?
+ AND stream_ordering > ?
+ AND {receipt_types_clause}
+ GROUP BY thread_id
+ ) AS receipts USING (thread_id)
WHERE user_id = ?
AND room_id = ?
- AND stream_ordering > ?
+ AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?)
AND highlight = 1
+ GROUP BY thread_id
"""
- txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
- row = txn.fetchone()
- if row:
- counts.highlight_count += row[0]
+ txn.execute(
+ sql,
+ (
+ user_id,
+ room_id,
+ unthreaded_receipt_stream_ordering,
+ *receipts_args,
+ user_id,
+ room_id,
+ unthreaded_receipt_stream_ordering,
+ ),
+ )
+ for highlight_count, thread_id in txn:
+ _get_thread(thread_id).highlight_count += highlight_count
+
+ # For threads which were summarised we need to count actions since the last
+ # rotation.
+ thread_id_clause, thread_id_args = make_in_list_sql_clause(
+ self.database_engine, "thread_id", summarised_threads
+ )
+
+ # The (inclusive) event stream ordering that was previously summarised.
+ rotated_upto_stream_ordering = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ table="event_push_summary_stream_ordering",
+ keyvalues={},
+ retcol="stream_ordering",
+ )
+
+ unread_counts = self._get_notif_unread_count_for_user_room(
+ txn, room_id, user_id, rotated_upto_stream_ordering
+ )
+ for notif_count, unread_count, thread_id in unread_counts:
+ if thread_id not in summarised_threads:
+ continue
+
+ if thread_id == MAIN_TIMELINE:
+ main_counts.notify_count += notif_count
+ main_counts.unread_count += unread_count
+ elif thread_id in thread_counts:
+ thread_counts[thread_id].notify_count += notif_count
+ thread_counts[thread_id].unread_count += unread_count
+ else:
+ # Previous thread summaries of 0 are discarded above.
+ #
+ # TODO If empty summaries are deleted this can be removed.
+ thread_counts[thread_id] = NotifCounts(
+ notify_count=notif_count,
+ unread_count=unread_count,
+ highlight_count=0,
+ )
# Finally we need to count push actions that aren't included in the
# summary returned above. This might be due to recent events that haven't
# been summarised yet or the summary is out of date due to a recent read
# receipt.
- start_unread_stream_ordering = max(
- receipt_stream_ordering, summary_stream_ordering
- )
- notify_count, unread_count = self._get_notif_unread_count_for_user_room(
- txn, room_id, user_id, start_unread_stream_ordering
+ sql = f"""
+ SELECT
+ COUNT(CASE WHEN notif = 1 THEN 1 END),
+ COUNT(CASE WHEN unread = 1 THEN 1 END),
+ thread_id
+ FROM event_push_actions
+ LEFT JOIN (
+ SELECT thread_id, MAX(stream_ordering) AS threaded_receipt_stream_ordering
+ FROM receipts_linearized
+ LEFT JOIN events USING (room_id, event_id)
+ WHERE
+ user_id = ?
+ AND room_id = ?
+ AND stream_ordering > ?
+ AND {receipt_types_clause}
+ GROUP BY thread_id
+ ) AS receipts USING (thread_id)
+ WHERE user_id = ?
+ AND room_id = ?
+ AND stream_ordering > COALESCE(threaded_receipt_stream_ordering, ?)
+ AND NOT {thread_id_clause}
+ GROUP BY thread_id
+ """
+ txn.execute(
+ sql,
+ (
+ user_id,
+ room_id,
+ unthreaded_receipt_stream_ordering,
+ *receipts_args,
+ user_id,
+ room_id,
+ unthreaded_receipt_stream_ordering,
+ *thread_id_args,
+ ),
)
+ for notif_count, unread_count, thread_id in txn:
+ counts = _get_thread(thread_id)
+ counts.notify_count += notif_count
+ counts.unread_count += unread_count
- counts.notify_count += notify_count
- counts.unread_count += unread_count
-
- return counts
+ return RoomNotifCounts(main_counts, thread_counts)
def _get_notif_unread_count_for_user_room(
self,
@@ -540,7 +708,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str,
stream_ordering: int,
max_stream_ordering: Optional[int] = None,
- ) -> Tuple[int, int]:
+ thread_id: Optional[str] = None,
+ ) -> List[Tuple[int, int, str]]:
"""Returns the notify and unread counts from `event_push_actions` for
the given user/room in the given range.
@@ -554,45 +723,55 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
stream_ordering: The (exclusive) minimum stream ordering to consider.
max_stream_ordering: The (inclusive) maximum stream ordering to consider.
If this is not given, then no maximum is applied.
+ thread_id: The thread ID to fetch unread counts for. If this is not provided
+ then the results for *all* threads are returned.
+
+ Note that if this is provided the resulting list will only have 0 or
+ 1 tuples in it.
Return:
- A tuple of the notif count and unread count in the given range.
+ A list of tuples of the notif count, the unread count and the thread ID
+ for each thread with push actions in the given range.
"""
# If there have been no events in the room since the stream ordering,
# there can't be any push actions either.
if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
- return 0, 0
+ return []
- clause = ""
+ stream_ordering_clause = ""
args = [user_id, room_id, stream_ordering]
if max_stream_ordering is not None:
- clause = "AND ea.stream_ordering <= ?"
+ stream_ordering_clause = "AND ea.stream_ordering <= ?"
args.append(max_stream_ordering)
# If the max stream ordering is less than the min stream ordering,
# then obviously there are zero push actions in that range.
if max_stream_ordering <= stream_ordering:
- return 0, 0
+ return []
+
+ # Either limit the results to a specific thread or fetch all threads.
+ thread_id_clause = ""
+ if thread_id is not None:
+ thread_id_clause = "AND thread_id = ?"
+ args.append(thread_id)
sql = f"""
SELECT
COUNT(CASE WHEN notif = 1 THEN 1 END),
- COUNT(CASE WHEN unread = 1 THEN 1 END)
- FROM event_push_actions ea
- WHERE user_id = ?
+ COUNT(CASE WHEN unread = 1 THEN 1 END),
+ thread_id
+ FROM event_push_actions ea
+ WHERE user_id = ?
AND room_id = ?
AND ea.stream_ordering > ?
- {clause}
+ {stream_ordering_clause}
+ {thread_id_clause}
+ GROUP BY thread_id
"""
txn.execute(sql, args)
- row = txn.fetchone()
-
- if row:
- return cast(Tuple[int, int], row)
-
- return 0, 0
+ return cast(List[Tuple[int, int, str]], txn.fetchall())
async def get_push_action_users_in_range(
self, min_stream_ordering: int, max_stream_ordering: int
@@ -609,7 +788,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def _get_receipts_by_room_txn(
self, txn: LoggingTransaction, user_id: str
- ) -> Dict[str, int]:
+ ) -> Dict[str, _RoomReceipt]:
"""
Generate a map of room ID to the latest stream ordering that has been
read by the given user.
@@ -619,7 +798,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: The user to fetch receipts for.
Returns:
- A map of room ID to stream ordering for all rooms the user has a receipt in.
+ A map from room ID to the user's latest _RoomReceipt, covering every
+ room the user has a receipt in.
"""
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
@@ -628,20 +808,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
sql = f"""
- SELECT room_id, MAX(stream_ordering)
+ SELECT room_id, thread_id, MAX(stream_ordering)
FROM receipts_linearized
INNER JOIN events USING (room_id, event_id)
WHERE {receipt_types_clause}
AND user_id = ?
- GROUP BY room_id
+ GROUP BY room_id, thread_id
"""
args.extend((user_id,))
txn.execute(sql, args)
- return {
- room_id: latest_stream_ordering
- for room_id, latest_stream_ordering in txn.fetchall()
- }
+
+ result: Dict[str, _RoomReceipt] = {}
+ for room_id, thread_id, stream_ordering in txn:
+ room_receipt = result.setdefault(room_id, _RoomReceipt())
+ if thread_id is None:
+ room_receipt.unthreaded_stream_ordering = stream_ordering
+ else:
+ room_receipt.threaded_stream_ordering[thread_id] = stream_ordering
+
+ return result
async def get_unread_push_actions_for_user_in_range_for_http(
self,
@@ -674,9 +860,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_push_actions_txn(
txn: LoggingTransaction,
- ) -> List[Tuple[str, str, int, str, bool]]:
+ ) -> List[Tuple[str, str, str, int, str, bool]]:
sql = """
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, ep.highlight
+ SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
+ ep.actions, ep.highlight
FROM event_push_actions AS ep
WHERE
ep.user_id = ?
@@ -686,7 +873,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
ORDER BY ep.stream_ordering ASC LIMIT ?
"""
txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
- return cast(List[Tuple[str, str, int, str, bool]], txn.fetchall())
+ return cast(List[Tuple[str, str, str, int, str, bool]], txn.fetchall())
push_actions = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
@@ -699,10 +886,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
stream_ordering=stream_ordering,
actions=_deserialize_action(actions, highlight),
)
- for event_id, room_id, stream_ordering, actions, highlight in push_actions
- # Only include push actions with a stream ordering after any receipt, or without any
- # receipt present (invited to but never read rooms).
- if stream_ordering > receipts_by_room.get(room_id, 0)
+ for event_id, room_id, thread_id, stream_ordering, actions, highlight in push_actions
+ if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
+ thread_id, stream_ordering
+ )
]
# Now sort it so it's ordered correctly, since currently it will
@@ -746,10 +933,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def get_push_actions_txn(
txn: LoggingTransaction,
- ) -> List[Tuple[str, str, int, str, bool, int]]:
+ ) -> List[Tuple[str, str, str, int, str, bool, int]]:
sql = """
- SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions,
- ep.highlight, e.received_ts
+ SELECT ep.event_id, ep.room_id, ep.thread_id, ep.stream_ordering,
+ ep.actions, ep.highlight, e.received_ts
FROM event_push_actions AS ep
INNER JOIN events AS e USING (room_id, event_id)
WHERE
@@ -760,7 +947,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
ORDER BY ep.stream_ordering DESC LIMIT ?
"""
txn.execute(sql, (user_id, min_stream_ordering, max_stream_ordering, limit))
- return cast(List[Tuple[str, str, int, str, bool, int]], txn.fetchall())
+ return cast(List[Tuple[str, str, str, int, str, bool, int]], txn.fetchall())
push_actions = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
@@ -775,10 +962,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
actions=_deserialize_action(actions, highlight),
received_ts=received_ts,
)
- for event_id, room_id, stream_ordering, actions, highlight, received_ts in push_actions
- # Only include push actions with a stream ordering after any receipt, or without any
- # receipt present (invited to but never read rooms).
- if stream_ordering > receipts_by_room.get(room_id, 0)
+ for event_id, room_id, thread_id, stream_ordering, actions, highlight, received_ts in push_actions
+ if receipts_by_room.get(room_id, MISSING_ROOM_RECEIPT).is_unread(
+ thread_id, stream_ordering
+ )
]
# Now sort it so it's ordered correctly, since currently it will
@@ -1102,7 +1289,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
sql = """
- SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
+ SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, e.stream_ordering
FROM receipts_linearized AS r
INNER JOIN events AS e USING (event_id)
WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ?
@@ -1123,55 +1310,86 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
limit,
),
)
- rows = cast(List[Tuple[int, str, str, int]], txn.fetchall())
+ rows = cast(List[Tuple[int, str, str, Optional[str], int]], txn.fetchall())
# For each new read receipt we delete push actions from before it and
# recalculate the summary.
- for _, room_id, user_id, stream_ordering in rows:
+ #
+ # Care must be taken to handle threaded and unthreaded receipts differently.
+ for _, room_id, user_id, thread_id, stream_ordering in rows:
# Only handle our own read receipts.
if not self.hs.is_mine_id(user_id):
continue
+ thread_clause = ""
+ thread_args: Tuple = ()
+ if thread_id is not None:
+ thread_clause = "AND thread_id = ?"
+ thread_args = (thread_id,)
+
+ # Delete the push actions covered by this receipt (scoped to its thread
+ # for a threaded receipt) before recalculating the summary.
txn.execute(
- """
+ f"""
DELETE FROM event_push_actions
WHERE room_id = ?
AND user_id = ?
AND stream_ordering <= ?
AND highlight = 0
+ {thread_clause}
""",
- (room_id, user_id, stream_ordering),
+ (room_id, user_id, stream_ordering, *thread_args),
)
# Fetch the notification counts between the stream ordering of the
# latest receipt and what was previously summarised.
- notif_count, unread_count = self._get_notif_unread_count_for_user_room(
- txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
+ unread_counts = self._get_notif_unread_count_for_user_room(
+ txn,
+ room_id,
+ user_id,
+ stream_ordering,
+ old_rotate_stream_ordering,
+ thread_id,
)
- # First ensure that the existing rows have an updated thread_id field.
- txn.execute(
- """
- UPDATE event_push_summary
- SET thread_id = ?
- WHERE room_id = ? AND user_id = ? AND thread_id is NULL
- """,
- ("main", room_id, user_id),
- )
+ # For an unthreaded receipt, mark the summary for all threads in the room
+ # as cleared.
+ if thread_id is None:
+ self.db_pool.simple_update_txn(
+ txn,
+ table="event_push_summary",
+ keyvalues={"user_id": user_id, "room_id": room_id},
+ updatevalues={
+ "notif_count": 0,
+ "unread_count": 0,
+ "stream_ordering": old_rotate_stream_ordering,
+ "last_receipt_stream_ordering": stream_ordering,
+ },
+ )
- # Replace the previous summary with the new counts.
- #
- # TODO(threads): Upsert per-thread instead of setting them all to main.
- self.db_pool.simple_upsert_txn(
+ # For a threaded receipt, we *always* want to update the summary for that
+ # thread, even if there are no new notifications in it. This ensures
+ # the stream_ordering & last_receipt_stream_ordering are updated.
+ elif not unread_counts:
+ unread_counts = [(0, 0, thread_id)]
+
+ # Then any updated threads get their notification count and unread
+ # count updated.
+ self.db_pool.simple_update_many_txn(
txn,
table="event_push_summary",
- keyvalues={"room_id": room_id, "user_id": user_id, "thread_id": "main"},
- values={
- "notif_count": notif_count,
- "unread_count": unread_count,
- "stream_ordering": old_rotate_stream_ordering,
- "last_receipt_stream_ordering": stream_ordering,
- },
+ key_names=("room_id", "user_id", "thread_id"),
+ key_values=[(room_id, user_id, row[2]) for row in unread_counts],
+ value_names=(
+ "notif_count",
+ "unread_count",
+ "stream_ordering",
+ "last_receipt_stream_ordering",
+ ),
+ value_values=[
+ (row[0], row[1], old_rotate_stream_ordering, stream_ordering)
+ for row in unread_counts
+ ],
)
# We always update `event_push_summary_last_receipt_stream_id` to
@@ -1259,23 +1477,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# Calculate the new counts that should be upserted into event_push_summary
sql = """
- SELECT user_id, room_id,
+ SELECT user_id, room_id, thread_id,
coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering
FROM (
- SELECT user_id, room_id, count(*) as cnt,
+ SELECT user_id, room_id, thread_id, count(*) as cnt,
max(ea.stream_ordering) as stream_ordering
FROM event_push_actions AS ea
- LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+ LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
AND (
old.last_receipt_stream_ordering IS NULL
OR old.last_receipt_stream_ordering < ea.stream_ordering
)
AND %s = 1
- GROUP BY user_id, room_id
+ GROUP BY user_id, room_id, thread_id
) AS upd
- LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+ LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
"""
# First get the count of unread messages.
@@ -1289,11 +1507,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to
# populate.
- summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
+ summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {}
for row in txn:
- summaries[(row[0], row[1])] = _EventPushSummary(
- unread_count=row[2],
- stream_ordering=row[3],
+ summaries[(row[0], row[1], row[2])] = _EventPushSummary(
+ unread_count=row[3],
+ stream_ordering=row[4],
notif_count=0,
)
@@ -1304,48 +1522,36 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
for row in txn:
- if (row[0], row[1]) in summaries:
- summaries[(row[0], row[1])].notif_count = row[2]
+ if (row[0], row[1], row[2]) in summaries:
+ summaries[(row[0], row[1], row[2])].notif_count = row[3]
else:
# Because the rules on notifying are different than the rules on marking
# a message unread, we might end up with messages that notify but aren't
# marked unread, so we might not have a summary for this (user, room)
# tuple to complete.
- summaries[(row[0], row[1])] = _EventPushSummary(
+ summaries[(row[0], row[1], row[2])] = _EventPushSummary(
unread_count=0,
- stream_ordering=row[3],
- notif_count=row[2],
+ stream_ordering=row[4],
+ notif_count=row[3],
)
logger.info("Rotating notifications, handling %d rows", len(summaries))
- # Ensure that any updated threads have an updated thread_id.
- txn.execute_batch(
- """
- UPDATE event_push_summary
- SET thread_id = ?
- WHERE room_id = ? AND user_id = ? AND thread_id is NULL
- """,
- [("main", room_id, user_id) for user_id, room_id in summaries],
- )
- self.db_pool.simple_update_many_txn(
- txn,
- table="event_push_summary",
- key_names=("user_id", "room_id", "thread_id"),
- key_values=[(user_id, room_id, None) for user_id, room_id in summaries],
- value_names=("thread_id",),
- value_values=[("main",) for _ in summaries],
- )
-
- # TODO(threads): Update on a per-thread basis.
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
key_names=("user_id", "room_id", "thread_id"),
- key_values=[(user_id, room_id, "main") for user_id, room_id in summaries],
+ key_values=[
+ (user_id, room_id, thread_id)
+ for user_id, room_id, thread_id in summaries
+ ],
value_names=("notif_count", "unread_count", "stream_ordering"),
value_values=[
- (summary.notif_count, summary.unread_count, summary.stream_ordering)
+ (
+ summary.notif_count,
+ summary.unread_count,
+ summary.stream_ordering,
+ )
for summary in summaries.values()
],
)
@@ -1356,7 +1562,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
async def _remove_old_push_actions_that_have_rotated(self) -> None:
- """Clear out old push actions that have been summarised."""
+ """
+ Clear out old push actions that have been summarised (and are older than
+ 1 day ago).
+ """
# We want to clear out anything that is older than a day that *has* already
# been rotated.
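
A condensed sketch of the per-thread accumulation the store now performs: rows of (notif_count, unread_count, thread_id) are folded into a main-timeline bucket plus a per-thread dict, mirroring RoomNotifCounts. The names below are simplified stand-ins.

from dataclasses import dataclass
from typing import Dict, Iterable, Tuple

MAIN_TIMELINE = "main"

@dataclass
class Counts:
    notify_count: int = 0
    unread_count: int = 0

def fold_rows(rows: Iterable[Tuple[int, int, str]]) -> Tuple[Counts, Dict[str, Counts]]:
    main = Counts()
    threads: Dict[str, Counts] = {}
    for notify, unread, thread_id in rows:
        bucket = main if thread_id == MAIN_TIMELINE else threads.setdefault(thread_id, Counts())
        bucket.notify_count += notify
        bucket.unread_count += unread
    return main, threads

main, threads = fold_rows([(1, 2, "main"), (0, 1, "$t1"), (2, 0, "$t1")])
assert (main.notify_count, main.unread_count) == (1, 2)
assert (threads["$t1"].notify_count, threads["$t1"].unread_count) == (2, 1)
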
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3e15827986..6698cbf664 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -35,7 +35,7 @@ import attr
from prometheus_client import Counter
import synapse.metrics
-from synapse.api.constants import EventContentFields, EventTypes
+from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
@@ -1616,7 +1616,7 @@ class PersistEventsStore:
)
# Remove from relations table.
- self._handle_redact_relations(txn, event.redacts)
+ self._handle_redact_relations(txn, event.room_id, event.redacts)
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
@@ -1866,6 +1866,34 @@ class PersistEventsStore:
},
)
+ if relation.rel_type == RelationTypes.THREAD:
+ # Upsert into the threads table, but only overwrite the value if the
+ # new event is of a later topological order OR if the topological
+ # ordering is equal, but the stream ordering is later.
+ sql = """
+ INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
+ VALUES (?, ?, ?, ?, ?)
+ ON CONFLICT (room_id, thread_id)
+ DO UPDATE SET
+ latest_event_id = excluded.latest_event_id,
+ topological_ordering = excluded.topological_ordering,
+ stream_ordering = excluded.stream_ordering
+ WHERE
+ threads.topological_ordering <= excluded.topological_ordering AND
+ threads.stream_ordering < excluded.stream_ordering
+ """
+
+ txn.execute(
+ sql,
+ (
+ event.room_id,
+ relation.parent_id,
+ event.event_id,
+ event.depth,
+ event.internal_metadata.stream_ordering,
+ ),
+ )
+
def _handle_insertion_event(
self, txn: LoggingTransaction, event: EventBase
) -> None:
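
As a rough illustration of the conditional upsert used above (not the PersistEventsStore code path itself), the following standalone sqlite3 sketch shows how the ON CONFLICT ... DO UPDATE ... WHERE clause keeps latest_event_id pointing at the reply with the highest (topological_ordering, stream_ordering); the room and event IDs are made up:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        """
        CREATE TABLE threads (
            room_id TEXT NOT NULL,
            thread_id TEXT NOT NULL,
            latest_event_id TEXT NOT NULL,
            topological_ordering BIGINT NOT NULL,
            stream_ordering BIGINT NOT NULL,
            CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id)
        )
        """
    )

    UPSERT = """
        INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT (room_id, thread_id)
        DO UPDATE SET
            latest_event_id = excluded.latest_event_id,
            topological_ordering = excluded.topological_ordering,
            stream_ordering = excluded.stream_ordering
        WHERE
            threads.topological_ordering <= excluded.topological_ordering AND
            threads.stream_ordering < excluded.stream_ordering
    """

    # The first reply creates the row, a newer reply replaces it, and a stale
    # (older-ordered) reply is ignored by the WHERE clause.
    conn.execute(UPSERT, ("!room:example", "$root", "$reply1", 5, 100))
    conn.execute(UPSERT, ("!room:example", "$root", "$reply2", 6, 101))
    conn.execute(UPSERT, ("!room:example", "$root", "$stale", 4, 90))
    print(conn.execute("SELECT latest_event_id FROM threads").fetchone())  # ('$reply2',)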
@@ -1989,13 +2017,14 @@ class PersistEventsStore:
txn.execute(sql, (batch_id,))
def _handle_redact_relations(
- self, txn: LoggingTransaction, redacted_event_id: str
+ self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
) -> None:
"""Handles receiving a redaction and checking whether the redacted event
has any relations which must be removed from the database.
Args:
txn
+ room_id: The room ID of the event that was redacted.
redacted_event_id: The event that was redacted.
"""
@@ -2025,9 +2054,7 @@ class PersistEventsStore:
txn, self.store.get_thread_participated, (redacted_relates_to,)
)
self.store._invalidate_cache_and_stream(
- txn,
- self.store.get_mutual_event_relations_for_rel_type,
- (redacted_relates_to,),
+ txn, self.store.get_threads, (room_id,)
)
self.db_pool.simple_delete_txn(
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 7cdc9fe98f..d4104462b5 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -474,7 +474,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
- event_entry_map = await self._get_events_from_cache_or_db(
+ event_entry_map = await self.get_unredacted_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@@ -509,7 +509,9 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
- event_map = await self._get_events_from_cache_or_db([redacted_event_id])
+ event_map = await self.get_unredacted_events_from_cache_or_db(
+ [redacted_event_id]
+ )
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@@ -588,11 +590,16 @@ class EventsWorkerStore(SQLBaseStore):
return events
@cancellable
- async def _get_events_from_cache_or_db(
- self, event_ids: Iterable[str], allow_rejected: bool = False
+ async def get_unredacted_events_from_cache_or_db(
+ self,
+ event_ids: Iterable[str],
+ allow_rejected: bool = False,
) -> Dict[str, EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.
+ Note that the events pulled by this function will not have any redactions
+ applied, and no guarantee is made about the ordering of the events returned.
+
If events are pulled from the database, they will be cached for future lookups.
Unknown events are omitted from the response.
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index ed17b2e70c..51416b2236 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -29,7 +29,6 @@ from typing import (
)
from synapse.api.errors import StoreError
-from synapse.config.homeserver import ExperimentalConfig
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
@@ -63,9 +62,7 @@ logger = logging.getLogger(__name__)
def _load_rules(
- rawrules: List[JsonDict],
- enabled_map: Dict[str, bool],
- experimental_config: ExperimentalConfig,
+ rawrules: List[JsonDict], enabled_map: Dict[str, bool]
) -> FilteredPushRules:
"""Take the DB rows returned from the DB and convert them into a full
`FilteredPushRules` object.
@@ -81,16 +78,9 @@ def _load_rules(
for rawrule in rawrules
]
- push_rules = PushRules(
- ruleslist,
- )
+ push_rules = PushRules(ruleslist)
- filtered_rules = FilteredPushRules(
- push_rules,
- enabled_map,
- msc3786_enabled=experimental_config.msc3786_enabled,
- msc3772_enabled=experimental_config.msc3772_enabled,
- )
+ filtered_rules = FilteredPushRules(push_rules, enabled_map)
return filtered_rules
@@ -170,7 +160,7 @@ class PushRulesWorkerStore(
enabled_map = await self.get_push_rules_enabled_for_user(user_id)
- return _load_rules(rows, enabled_map, self.hs.config.experimental)
+ return _load_rules(rows, enabled_map)
async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]:
results = await self.db_pool.simple_select_list(
@@ -229,9 +219,7 @@ class PushRulesWorkerStore(
results: Dict[str, FilteredPushRules] = {}
for user_id, rules in raw_rules.items():
- results[user_id] = _load_rules(
- rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
- )
+ results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
return results
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 246f78ac1f..dc6989527e 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -418,6 +418,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = db_to_json(row["data"])
+ if row["thread_id"]:
+ receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
results = {
room_id: [results[room_id]] if room_id in results else []
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 898947af95..1de62ee9df 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -14,6 +14,7 @@
import logging
from typing import (
+ TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
@@ -28,19 +29,48 @@ from typing import (
import attr
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import MAIN_TIMELINE, RelationTypes
+from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+ make_in_list_sql_clause,
+)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@attr.s(slots=True, frozen=True, auto_attribs=True)
+class ThreadsNextBatch:
+ topological_ordering: int
+ stream_ordering: int
+
+ def __str__(self) -> str:
+ return f"{self.topological_ordering}_{self.stream_ordering}"
+
+ @classmethod
+ def from_string(cls, string: str) -> "ThreadsNextBatch":
+ """
+ Creates a ThreadsNextBatch from its textual representation.
+ """
+ try:
+ keys = (int(s) for s in string.split("_"))
+ return cls(*keys)
+ except Exception:
+ raise SynapseError(400, "Invalid threads token")
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
class _RelatedEvent:
"""
Contains enough information about a related event in order to properly filter
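
The ThreadsNextBatch token above is just the pair <topological>_<stream> rendered as text. A minimal round-trip sketch, using a plain dataclass as a stand-in for the attrs class (illustrative only):

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class ThreadsNextBatchSketch:
        # Simplified stand-in for the attrs-based ThreadsNextBatch above.
        topological_ordering: int
        stream_ordering: int

        def __str__(self) -> str:
            return f"{self.topological_ordering}_{self.stream_ordering}"

        @classmethod
        def from_string(cls, string: str) -> "ThreadsNextBatchSketch":
            try:
                topological, stream = (int(s) for s in string.split("_"))
            except ValueError:
                # The real class raises a SynapseError(400, ...) here instead.
                raise ValueError("Invalid threads token")
            return cls(topological, stream)

    token = ThreadsNextBatchSketch(topological_ordering=10, stream_ordering=42)
    assert str(token) == "10_42"
    assert ThreadsNextBatchSketch.from_string("10_42") == token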
@@ -56,6 +86,76 @@ class _RelatedEvent:
class RelationsWorkerStore(SQLBaseStore):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_update_handler(
+ "threads_backfill", self._backfill_threads
+ )
+
+ async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int:
+ """Backfill the threads table."""
+
+ def threads_backfill_txn(txn: LoggingTransaction) -> int:
+ last_thread_id = progress.get("last_thread_id", "")
+
+ # Get the latest event in each thread by topo ordering / stream ordering.
+ #
+ # Note that the MAX(event_id) is needed to abide by the rules of group by,
+ # but doesn't actually do anything since there should only be a single event
+ # ID per topo/stream ordering pair.
+ sql = f"""
+ SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id > ? AND
+ relation_type = '{RelationTypes.THREAD}'
+ GROUP BY room_id, relates_to_id
+ ORDER BY relates_to_id
+ LIMIT ?
+ """
+ txn.execute(sql, (last_thread_id, batch_size))
+
+ # No more rows to process.
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ # Insert the rows into the threads table. If a matching thread already exists,
+ # assume it is from a newer event.
+ sql = """
+ INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id)
+ VALUES %s
+ ON CONFLICT (room_id, thread_id)
+ DO NOTHING
+ """
+ if isinstance(txn.database_engine, PostgresEngine):
+ txn.execute_values(sql % ("?",), rows, fetch=False)
+ else:
+ txn.execute_batch(sql % ("(?, ?, ?, ?, ?)",), rows)
+
+ # Mark the progress.
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "threads_backfill", {"last_thread_id": rows[-1][1]}
+ )
+
+ return txn.rowcount
+
+ result = await self.db_pool.runInteraction(
+ "threads_backfill", threads_backfill_txn
+ )
+
+ if not result:
+ await self.db_pool.updates._end_background_update("threads_backfill")
+
+ return result
+
@cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
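
The backfill above follows the usual background-update shape: fetch at most batch_size rows past the previous progress marker, store them, record the new marker, and return the number of rows handled (zero ends the update). A rough, Synapse-free sketch of that control flow; all names here are hypothetical:

    from typing import Callable, Dict, List, Tuple

    def run_backfill_to_completion(
        fetch_batch: Callable[[str, int], List[Tuple[str, str]]],
        store_rows: Callable[[List[Tuple[str, str]]], None],
        batch_size: int = 100,
    ) -> None:
        """Drive a backfill until a batch comes back empty, keeping a resume point."""
        progress: Dict[str, str] = {}
        while True:
            last_id = progress.get("last_thread_id", "")
            rows = fetch_batch(last_id, batch_size)
            if not rows:
                # Nothing left: this is where the background update would be ended.
                return
            store_rows(rows)
            # Remember the last key processed so a restart resumes from here.
            progress["last_thread_id"] = rows[-1][0]

    data = [("$thread_a", "!room"), ("$thread_b", "!room"), ("$thread_c", "!room")]
    run_backfill_to_completion(
        fetch_batch=lambda after, n: [r for r in data if r[0] > after][:n],
        store_rows=lambda rows: None,
        batch_size=2,
    )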
@@ -384,12 +484,11 @@ class RelationsWorkerStore(SQLBaseStore):
the event will map to None.
"""
- # We only allow edits for `m.room.message` events that have the same sender
- # and event type. We can't assert these things during regular event auth so
- # we have to do the checks post hoc.
+ # We only allow edits for events that have the same sender and event type.
+ # We can't assert these things during regular event auth so we have to do
+ # the checks post hoc.
- # Fetches latest edit that has the same type and sender as the
- # original, and is an `m.room.message`.
+ # Fetches latest edit that has the same type and sender as the original.
if isinstance(self.database_engine, PostgresEngine):
# The `DISTINCT ON` clause will pick the *first* row it encounters,
# so ordering by origin server ts + event ID desc will ensure we get
@@ -405,7 +504,6 @@ class RelationsWorkerStore(SQLBaseStore):
WHERE
%s
AND relation_type = ?
- AND edit.type = 'm.room.message'
ORDER by original.event_id DESC, edit.origin_server_ts DESC, edit.event_id DESC
"""
else:
@@ -424,7 +522,6 @@ class RelationsWorkerStore(SQLBaseStore):
WHERE
%s
AND relation_type = ?
- AND edit.type = 'm.room.message'
ORDER by edit.origin_server_ts, edit.event_id
"""
@@ -779,57 +876,192 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
- @cached(iterable=True)
- async def get_mutual_event_relations_for_rel_type(
- self, event_id: str, relation_type: str
- ) -> Set[Tuple[str, str]]:
- raise NotImplementedError()
+ @cached(tree=True)
+ async def get_threads(
+ self,
+ room_id: str,
+ limit: int = 5,
+ from_token: Optional[ThreadsNextBatch] = None,
+ ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+ """Get a list of thread IDs, ordered by topological ordering of their
+ latest reply.
+
+ Args:
+ room_id: The room the event belongs to.
+ limit: Only fetch the most recent `limit` threads.
+ from_token: Fetch rows from a previous next_batch, or from the start if None.
+
+ Returns:
+ A tuple of:
+ A list of thread root event IDs.
+
+ The next_batch, if one exists.
+ """
+ # Generate the pagination clause, if necessary.
+ #
+        # Find any threads whose latest reply is at or before the last
+        # thread's topo ordering and earlier in stream ordering.
+ pagination_clause = ""
+ pagination_args: tuple = ()
+ if from_token:
+ pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?"
+ pagination_args = (
+ from_token.topological_ordering,
+ from_token.stream_ordering,
+ )
- @cachedList(
- cached_method_name="get_mutual_event_relations_for_rel_type",
- list_name="relation_types",
- )
- async def get_mutual_event_relations(
- self, event_id: str, relation_types: Collection[str]
- ) -> Dict[str, Set[Tuple[str, str]]]:
+ sql = f"""
+ SELECT thread_id, topological_ordering, stream_ordering
+ FROM threads
+ WHERE
+ room_id = ?
+ {pagination_clause}
+ ORDER BY topological_ordering DESC, stream_ordering DESC
+ LIMIT ?
+ """
+
+ def _get_threads_txn(
+ txn: LoggingTransaction,
+ ) -> Tuple[List[str], Optional[ThreadsNextBatch]]:
+ txn.execute(sql, (room_id, *pagination_args, limit + 1))
+
+ rows = cast(List[Tuple[str, int, int]], txn.fetchall())
+ thread_ids = [r[0] for r in rows]
+
+ # If there are more events, generate the next pagination key from the
+ # last thread which will be returned.
+ next_token = None
+ if len(thread_ids) > limit:
+ last_topo_id = rows[-2][1]
+ last_stream_id = rows[-2][2]
+ next_token = ThreadsNextBatch(last_topo_id, last_stream_id)
+
+ return thread_ids[:limit], next_token
+
+ return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
+
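
The limit + 1 fetch in _get_threads_txn above is a standard keyset-pagination trick: request one extra row, and if it shows up, derive the next token from the last row that will actually be returned. The same idea over an in-memory list (a sketch, not the database code):

    from typing import List, Optional, Tuple

    def paginate_threads(
        rows: List[Tuple[str, int, int]],  # (thread_id, topological_ordering, stream_ordering)
        limit: int,
        from_token: Optional[Tuple[int, int]] = None,
    ) -> Tuple[List[str], Optional[Tuple[int, int]]]:
        # Rows are assumed pre-sorted by (topological_ordering, stream_ordering) DESC,
        # matching the ORDER BY in the SQL above.
        if from_token is not None:
            topo, stream = from_token
            rows = [r for r in rows if r[1] <= topo and r[2] < stream]
        page = rows[: limit + 1]
        thread_ids = [r[0] for r in page]
        next_token = None
        if len(thread_ids) > limit:
            # The next token points at the last thread that will be returned.
            next_token = (page[-2][1], page[-2][2])
        return thread_ids[:limit], next_token

    rows = [("$c", 9, 30), ("$b", 8, 20), ("$a", 7, 10)]
    first_page, token = paginate_threads(rows, limit=2)
    assert first_page == ["$c", "$b"] and token == (8, 20)
    second_page, token = paginate_threads(rows, limit=2, from_token=token)
    assert second_page == ["$a"] and token is None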
+ @cached()
+ async def get_thread_id(self, event_id: str) -> str:
"""
- Fetch event metadata for events which related to the same event as the given event.
+ Get the thread ID for an event. This considers multi-level relations,
+ e.g. an annotation to an event which is part of a thread.
+
+ It only searches up the relations tree, i.e. it only searches for events
+ which the given event is related to (and which those events are related
+ to, etc.)
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ get_thread_id(X) considers events B and C as part of thread A.
- If the given event has no relation information, returns an empty dictionary.
+ See also get_thread_id_for_receipts.
Args:
- event_id: The event ID which is targeted by relations.
- relation_types: The relation types to check for mutual relations.
+ event_id: The event ID to fetch the thread ID for.
Returns:
- A dictionary of relation type to:
- A set of tuples of:
- The sender
- The event type
+            The event ID of the root event in the thread, if this event is part
+            of a thread; "main" otherwise.
"""
- rel_type_sql, rel_type_args = make_in_list_sql_clause(
- self.database_engine, "relation_type", relation_types
- )
- sql = f"""
- SELECT DISTINCT relation_type, sender, type FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE relates_to_id = ? AND {rel_type_sql}
+ # Recurse event relations up to the *root* event, then search that chain
+ # of relations for a thread relation. If one is found, the root event is
+ # returned.
+ #
+ # Note that this should only ever find 0 or 1 entries since it is invalid
+ # for an event to have a thread relation to an event which also has a
+ # relation.
+ sql = """
+ WITH RECURSIVE related_events AS (
+ SELECT event_id, relates_to_id, relation_type, 0 depth
+ FROM event_relations
+ WHERE event_id = ?
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
+ FROM event_relations e
+ INNER JOIN related_events r ON r.relates_to_id = e.event_id
+ WHERE depth <= 3
+ )
+ SELECT relates_to_id FROM related_events
+ WHERE relation_type = 'm.thread'
+ ORDER BY depth DESC
+ LIMIT 1;
"""
- def _get_event_relations(
- txn: LoggingTransaction,
- ) -> Dict[str, Set[Tuple[str, str]]]:
- txn.execute(sql, [event_id] + rel_type_args)
- result: Dict[str, Set[Tuple[str, str]]] = {
- rel_type: set() for rel_type in relation_types
- }
- for rel_type, sender, type in txn.fetchall():
- result[rel_type].add((sender, type))
- return result
+ def _get_thread_id(txn: LoggingTransaction) -> str:
+ txn.execute(sql, (event_id,))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ # If no thread was found, it is part of the main timeline.
+ return MAIN_TIMELINE
+
+ return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)
+
+ @cached()
+ async def get_thread_id_for_receipts(self, event_id: str) -> str:
+ """
+ Get the thread ID for an event by traversing to the top-most related event
+        and confirming any child events form a thread.
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part
+ of thread A.
+
+ See also get_thread_id.
+
+ Args:
+ event_id: The event ID to fetch the thread ID for.
+
+ Returns:
+            The event ID of the root event in the thread, if this event is part
+            of a thread; "main" otherwise.
+ """
+
+ # Recurse event relations up to the *root* event, then search for any events
+ # related to that root node for a thread relation. If one is found, the
+ # root event is returned.
+ #
+ # Note that there cannot be thread relations in the middle of the chain since
+ # it is invalid for an event to have a thread relation to an event which also
+ # has a relation.
+ sql = """
+ SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE((
+ WITH RECURSIVE related_events AS (
+ SELECT event_id, relates_to_id, relation_type, 0 depth
+ FROM event_relations
+ WHERE event_id = ?
+ UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
+ FROM event_relations e
+ INNER JOIN related_events r ON r.relates_to_id = e.event_id
+ WHERE depth <= 3
+ )
+ SELECT relates_to_id FROM related_events
+ ORDER BY depth DESC
+ LIMIT 1
+ ), ?) AND relation_type = 'm.thread' LIMIT 1;
+ """
+
+ def _get_related_thread_id(txn: LoggingTransaction) -> str:
+ txn.execute(sql, (event_id, event_id))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ # If no thread was found, it is part of the main timeline.
+ return MAIN_TIMELINE
return await self.db_pool.runInteraction(
- "get_event_relations", _get_event_relations
+ "get_related_thread_id", _get_related_thread_id
)
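
Both recursive CTEs above walk up the relations graph (bounded at a depth of 3) and then look for an m.thread edge. The same shape of query, exercised against a throwaway sqlite3 database with made-up event IDs:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE event_relations (event_id TEXT, relates_to_id TEXT, relation_type TEXT)"
    )
    # A <--[m.thread]-- B <--[m.annotation]-- C
    conn.executemany(
        "INSERT INTO event_relations VALUES (?, ?, ?)",
        [("$B", "$A", "m.thread"), ("$C", "$B", "m.annotation")],
    )

    SQL = """
        WITH RECURSIVE related_events AS (
            SELECT event_id, relates_to_id, relation_type, 0 AS depth
            FROM event_relations
            WHERE event_id = ?
            UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1
            FROM event_relations e
            INNER JOIN related_events r ON r.relates_to_id = e.event_id
            WHERE depth <= 3
        )
        SELECT relates_to_id FROM related_events
        WHERE relation_type = 'm.thread'
        ORDER BY depth DESC
        LIMIT 1
    """

    row = conn.execute(SQL, ("$C",)).fetchone()
    print(row[0] if row else "main")  # "$A": the annotation C is counted as part of thread A.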
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 7412bce255..e41c99027a 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -207,21 +207,30 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _construct_room_type_where_clause(
self, room_types: Union[List[Union[str, None]], None]
- ) -> Tuple[Union[str, None], List[str]]:
+ ) -> Tuple[Union[str, None], list]:
if not room_types:
return None, []
- else:
- # We use None when we want get rooms without a type
- is_null_clause = ""
- if None in room_types:
- is_null_clause = "OR room_type IS NULL"
- room_types = [value for value in room_types if value is not None]
+        # Since None is used to represent a room without a type, care needs to
+        # be taken when constructing the where clause.
+ clauses = []
+ args: list = []
+
+ room_types_set = set(room_types)
+
+ # We use None to represent a room without a type.
+ if None in room_types_set:
+ clauses.append("room_type IS NULL")
+ room_types_set.remove(None)
+
+ # If there are other room types, generate the proper clause.
+        if room_types_set:
list_clause, args = make_in_list_sql_clause(
- self.database_engine, "room_type", room_types
+ self.database_engine, "room_type", room_types_set
)
+ clauses.append(list_clause)
- return f"({list_clause} {is_null_clause})", args
+ return f"({' OR '.join(clauses)})", args
async def count_public_rooms(
self,
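
Because None stands for "room with no type", the clause builder above has to emit room_type IS NULL separately from the IN (...) list (a SQL IN list never matches NULL). A standalone approximation of the same logic, with a simplified placeholder-based stand-in for make_in_list_sql_clause:

    from typing import List, Optional, Tuple

    def room_type_where_clause(
        room_types: Optional[List[Optional[str]]],
    ) -> Tuple[Optional[str], list]:
        if not room_types:
            return None, []

        clauses: List[str] = []
        args: list = []

        room_types_set = set(room_types)
        if None in room_types_set:
            # NULL never matches an IN list, so it needs its own clause.
            clauses.append("room_type IS NULL")
            room_types_set.remove(None)

        if room_types_set:
            placeholders = ", ".join("?" for _ in room_types_set)
            clauses.append(f"room_type IN ({placeholders})")
            args.extend(room_types_set)

        return f"({' OR '.join(clauses)})", args

    print(room_type_where_clause(["m.space", None]))
    # ('(room_type IS NULL OR room_type IN (?))', ['m.space'])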
@@ -241,14 +250,6 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = []
- room_type_clause, args = self._construct_room_type_where_clause(
- search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
- if search_filter
- else None
- )
- room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
- query_args += args
-
if network_tuple:
if network_tuple.appservice_id:
published_sql = """
@@ -268,6 +269,14 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
UNION SELECT room_id from appservice_room_list
"""
+ room_type_clause, args = self._construct_room_type_where_clause(
+ search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+ if search_filter
+ else None
+ )
+ room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+ query_args += args
+
sql = f"""
SELECT
COUNT(*)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 2337289d88..2ed6ad754f 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -666,7 +666,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cached_method_name="get_rooms_for_user",
list_name="user_ids",
)
- async def get_rooms_for_users(
+ async def _get_rooms_for_users(
self, user_ids: Collection[str]
) -> Dict[str, FrozenSet[str]]:
"""A batched version of `get_rooms_for_user`.
@@ -697,6 +697,21 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
+ async def get_rooms_for_users(
+ self, user_ids: Collection[str]
+ ) -> Dict[str, FrozenSet[str]]:
+ """A batched wrapper around `_get_rooms_for_users`, to prevent locking
+ other calls to `get_rooms_for_user` for large user lists.
+ """
+ all_user_rooms: Dict[str, FrozenSet[str]] = {}
+
+ # 250 users is pretty arbitrary but the data can be quite large if users
+ # are in many rooms.
+        for batch_user_ids in batch_iter(user_ids, 250):
+            all_user_rooms.update(await self._get_rooms_for_users(batch_user_ids))
+
+ return all_user_rooms
+
@cached(max_entries=10000)
async def does_pair_of_users_share_a_room(
self, user_id: str, other_user_id: str
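
The wrapper above simply chops the requested users into chunks of 250 before calling the batched, cached lookup, so one very large request cannot hold up other get_rooms_for_user calls. A sketch of that chunking with a simplified batch_iter (not Synapse's implementation) and a hypothetical lookup callable standing in for _get_rooms_for_users:

    from itertools import islice
    from typing import Dict, FrozenSet, Iterable, Iterator, Tuple

    def batch_iter(iterable: Iterable[str], size: int) -> Iterator[Tuple[str, ...]]:
        """Yield tuples of at most `size` items from `iterable`."""
        it = iter(iterable)
        while True:
            chunk = tuple(islice(it, size))
            if not chunk:
                return
            yield chunk

    def get_rooms_for_users_batched(
        user_ids: Iterable[str],
        lookup,  # callable mapping a chunk of user IDs to {user_id: frozenset(room_ids)}
        batch_size: int = 250,
    ) -> Dict[str, FrozenSet[str]]:
        all_user_rooms: Dict[str, FrozenSet[str]] = {}
        for chunk in batch_iter(user_ids, batch_size):
            all_user_rooms.update(lookup(chunk))
        return all_user_rooms

    result = get_rooms_for_users_batched(
        ["@a:test", "@b:test", "@c:test"],
        lookup=lambda chunk: {u: frozenset({"!room:test"}) for u in chunk},
        batch_size=2,
    )
    assert set(result) == {"@a:test", "@b:test", "@c:test"}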
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 530f04e149..5baffbfe55 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1024,28 +1024,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"after": {"event_ids": events_after, "token": end_token},
}
- async def get_all_new_events_stream(
- self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
- ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
+ async def get_all_new_event_ids_stream(
+ self,
+ from_id: int,
+ current_id: int,
+ limit: int,
+ ) -> Tuple[int, Dict[str, Optional[int]]]:
"""Get all new events
- Returns all events with from_id < stream_ordering <= current_id.
+ Returns all event ids with from_id < stream_ordering <= current_id.
Args:
from_id: the stream_ordering of the last event we processed
current_id: the stream_ordering of the most recently processed event
limit: the maximum number of events to return
- get_prev_content: whether to fetch previous event content
Returns:
- A tuple of (next_id, events, event_to_received_ts), where `next_id`
+ A tuple of (next_id, event_to_received_ts), where `next_id`
is the next value to pass as `from_id` (it will either be the
stream_ordering of the last returned event, or, if fewer than `limit`
events were found, the `current_id`). The `event_to_received_ts` is
- a dictionary mapping event ID to the event `received_ts`.
+ a dictionary mapping event ID to the event `received_ts`, sorted by ascending
+ stream_ordering.
"""
- def get_all_new_events_stream_txn(
+ def get_all_new_event_ids_stream_txn(
txn: LoggingTransaction,
) -> Tuple[int, Dict[str, Optional[int]]]:
sql = (
@@ -1070,15 +1073,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, event_to_received_ts
upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
- "get_all_new_events_stream", get_all_new_events_stream_txn
+ "get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn
)
- events = await self.get_events_as_list(
- event_to_received_ts.keys(),
- get_prev_content=get_prev_content,
- )
-
- return upper_bound, events, event_to_received_ts
+ return upper_bound, event_to_received_ts
async def get_federation_out_pos(self, typ: str) -> int:
if self._need_to_reset_federation_stream_positions:
@@ -1202,8 +1200,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
`to_token`), or `limit` is zero.
"""
- assert int(limit) >= 0
-
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 4a5c947699..19dbf2da7f 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -90,9 +90,9 @@ Changes in SCHEMA_VERSION = 73;
SCHEMA_COMPAT_VERSION = (
- # The groups tables are no longer accessible, so synapses with SCHEMA_VERSION < 72
- # could break.
- 72
+    # The thread_id column must exist for event_push_actions, event_push_summary,
+ # receipts_linearized, and receipts_graph.
+ 73
)
"""Limit on how far the synapse codebase can be rolled back without breaking db compat
diff --git a/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql
new file mode 100644
index 0000000000..0ffde9bbeb
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/06thread_notifications_backfill.sql
@@ -0,0 +1,29 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Forces the background updates from 06thread_notifications.sql to run in the
+-- foreground as code will now require those to be "done".
+
+DELETE FROM background_updates WHERE update_name = 'event_push_backfill_thread_id';
+
+-- Overwrite any null thread_id columns.
+UPDATE event_push_actions_staging SET thread_id = 'main' WHERE thread_id IS NULL;
+UPDATE event_push_actions SET thread_id = 'main' WHERE thread_id IS NULL;
+UPDATE event_push_summary SET thread_id = 'main' WHERE thread_id IS NULL;
+
+-- Do not run the event_push_summary_unique_index job if it is pending; the
+-- thread_id field will be made required.
+DELETE FROM background_updates WHERE update_name = 'event_push_summary_unique_index';
+DROP INDEX IF EXISTS event_push_summary_unique_index;
diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres
new file mode 100644
index 0000000000..33674f8c62
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.postgres
@@ -0,0 +1,19 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- The columns can now be made non-nullable.
+ALTER TABLE event_push_actions_staging ALTER COLUMN thread_id SET NOT NULL;
+ALTER TABLE event_push_actions ALTER COLUMN thread_id SET NOT NULL;
+ALTER TABLE event_push_summary ALTER COLUMN thread_id SET NOT NULL;
diff --git a/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite
new file mode 100644
index 0000000000..5322ad77a4
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/07thread_notifications_not_null.sql.sqlite
@@ -0,0 +1,101 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- SQLite doesn't support modifying columns to an existing table, so it must
+-- be recreated.
+
+-- Create the new tables.
+CREATE TABLE event_push_actions_staging_new (
+ event_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ actions TEXT NOT NULL,
+ notif SMALLINT NOT NULL,
+ highlight SMALLINT NOT NULL,
+ unread SMALLINT,
+ thread_id TEXT NOT NULL,
+ inserted_ts BIGINT
+);
+
+CREATE TABLE event_push_actions_new (
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ profile_tag VARCHAR(32),
+ actions TEXT NOT NULL,
+ topological_ordering BIGINT,
+ stream_ordering BIGINT,
+ notif SMALLINT,
+ highlight SMALLINT,
+ unread SMALLINT,
+ thread_id TEXT NOT NULL,
+ CONSTRAINT event_id_user_id_profile_tag_uniqueness UNIQUE (room_id, event_id, user_id, profile_tag)
+);
+
+CREATE TABLE event_push_summary_new (
+ user_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+ notif_count BIGINT NOT NULL,
+ stream_ordering BIGINT NOT NULL,
+ unread_count BIGINT,
+ last_receipt_stream_ordering BIGINT,
+ thread_id TEXT NOT NULL
+);
+
+-- Swap the indexes.
+DROP INDEX IF EXISTS event_push_actions_staging_id;
+CREATE INDEX event_push_actions_staging_id ON event_push_actions_staging_new(event_id);
+
+DROP INDEX IF EXISTS event_push_actions_room_id_user_id;
+DROP INDEX IF EXISTS event_push_actions_rm_tokens;
+DROP INDEX IF EXISTS event_push_actions_stream_ordering;
+DROP INDEX IF EXISTS event_push_actions_u_highlight;
+DROP INDEX IF EXISTS event_push_actions_highlights_index;
+CREATE INDEX event_push_actions_room_id_user_id on event_push_actions_new(room_id, user_id);
+CREATE INDEX event_push_actions_rm_tokens on event_push_actions_new( user_id, room_id, topological_ordering, stream_ordering );
+CREATE INDEX event_push_actions_stream_ordering on event_push_actions_new( stream_ordering, user_id );
+CREATE INDEX event_push_actions_u_highlight ON event_push_actions_new (user_id, stream_ordering);
+CREATE INDEX event_push_actions_highlights_index ON event_push_actions_new (user_id, room_id, topological_ordering, stream_ordering);
+
+-- Copy the data.
+INSERT INTO event_push_actions_staging_new (event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts)
+ SELECT event_id, user_id, actions, notif, highlight, unread, thread_id, inserted_ts
+ FROM event_push_actions_staging;
+
+INSERT INTO event_push_actions_new (room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id)
+ SELECT room_id, event_id, user_id, profile_tag, actions, topological_ordering, stream_ordering, notif, highlight, unread, thread_id
+ FROM event_push_actions;
+
+INSERT INTO event_push_summary_new (user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id)
+ SELECT user_id, room_id, notif_count, stream_ordering, unread_count, last_receipt_stream_ordering, thread_id
+ FROM event_push_summary;
+
+-- Drop the old tables.
+DROP TABLE event_push_actions_staging;
+DROP TABLE event_push_actions;
+DROP TABLE event_push_summary;
+
+-- Rename the tables.
+ALTER TABLE event_push_actions_staging_new RENAME TO event_push_actions_staging;
+ALTER TABLE event_push_actions_new RENAME TO event_push_actions;
+ALTER TABLE event_push_summary_new RENAME TO event_push_summary;
+
+-- Re-run background updates from 72/02event_push_actions_index.sql and
+-- 72/06thread_notifications.sql.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7307, 'event_push_summary_unique_index2', '{}')
+ ON CONFLICT (update_name) DO NOTHING;
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7307, 'event_push_actions_stream_highlight_index', '{}')
+ ON CONFLICT (update_name) DO NOTHING;
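
The migration above is the standard SQLite workaround for the missing ALTER COLUMN: create a replacement table with the desired column definitions, copy the rows across, drop the original, and rename. Reduced to a toy example via Python's sqlite3 module (table and column names here are invented):

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE push_summary (user_id TEXT NOT NULL, thread_id TEXT);
        INSERT INTO push_summary VALUES ('@alice:test', 'main');

        -- Recreate the table with thread_id made NOT NULL.
        CREATE TABLE push_summary_new (user_id TEXT NOT NULL, thread_id TEXT NOT NULL);
        INSERT INTO push_summary_new (user_id, thread_id)
            SELECT user_id, thread_id FROM push_summary;
        DROP TABLE push_summary;
        ALTER TABLE push_summary_new RENAME TO push_summary;
        """
    )
    print(conn.execute("SELECT * FROM push_summary").fetchall())  # [('@alice:test', 'main')]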
diff --git a/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.postgres b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.postgres
new file mode 100644
index 0000000000..3e0bc9e5eb
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.postgres
@@ -0,0 +1,23 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Drop constraint on (room_id, receipt_type, user_id).
+
+-- The thread-aware unique constraints already exist, so the old ones can simply be dropped.
+ALTER TABLE receipts_linearized
+ DROP CONSTRAINT receipts_linearized_uniqueness;
+
+ALTER TABLE receipts_graph
+ DROP CONSTRAINT receipts_graph_uniqueness;
diff --git a/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.sqlite b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.sqlite
new file mode 100644
index 0000000000..e664889fbc
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/08thread_receipts_non_null.sql.sqlite
@@ -0,0 +1,76 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Drop constraint on (room_id, receipt_type, user_id).
+--
+-- SQLite doesn't support modifying constraints to an existing table, so it must
+-- be recreated.
+
+-- Create the new tables.
+CREATE TABLE receipts_linearized_new (
+ stream_id BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ receipt_type TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ thread_id TEXT,
+ event_stream_ordering BIGINT,
+ data TEXT NOT NULL,
+ CONSTRAINT receipts_linearized_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id)
+);
+
+CREATE TABLE receipts_graph_new (
+ room_id TEXT NOT NULL,
+ receipt_type TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ event_ids TEXT NOT NULL,
+ thread_id TEXT,
+ data TEXT NOT NULL,
+ CONSTRAINT receipts_graph_uniqueness_thread UNIQUE (room_id, receipt_type, user_id, thread_id)
+);
+
+-- Drop the old indexes.
+DROP INDEX IF EXISTS receipts_linearized_id;
+DROP INDEX IF EXISTS receipts_linearized_room_stream;
+DROP INDEX IF EXISTS receipts_linearized_user;
+
+-- Copy the data.
+INSERT INTO receipts_linearized_new (stream_id, room_id, receipt_type, user_id, event_id, data)
+ SELECT stream_id, room_id, receipt_type, user_id, event_id, data
+ FROM receipts_linearized;
+INSERT INTO receipts_graph_new (room_id, receipt_type, user_id, event_ids, data)
+ SELECT room_id, receipt_type, user_id, event_ids, data
+ FROM receipts_graph;
+
+-- Drop the old tables.
+DROP TABLE receipts_linearized;
+DROP TABLE receipts_graph;
+
+-- Rename the tables.
+ALTER TABLE receipts_linearized_new RENAME TO receipts_linearized;
+ALTER TABLE receipts_graph_new RENAME TO receipts_graph;
+
+-- Create the indices.
+CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id );
+CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id );
+CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id );
+
+-- Re-run background updates from 72/08thread_receipts.sql.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7308, 'receipts_linearized_unique_index', '{}')
+ ON CONFLICT (update_name) DO NOTHING;
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7308, 'receipts_graph_unique_index', '{}')
+ ON CONFLICT (update_name) DO NOTHING;
diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql
new file mode 100644
index 0000000000..aa7c5e9a2e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/73/09threads_table.sql
@@ -0,0 +1,30 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE threads (
+ room_id TEXT NOT NULL,
+ -- The event ID of the root event in the thread.
+ thread_id TEXT NOT NULL,
+ -- The latest event ID and corresponding topo / stream ordering.
+ latest_event_id TEXT NOT NULL,
+ topological_ordering BIGINT NOT NULL,
+ stream_ordering BIGINT NOT NULL,
+ CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id)
+);
+
+CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering);
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (7309, 'threads_backfill', '{}');
diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py
index 806b671305..2dcd43d0a2 100644
--- a/synapse/streams/__init__.py
+++ b/synapse/streams/__init__.py
@@ -27,7 +27,7 @@ class EventSource(Generic[K, R]):
self,
user: UserID,
from_key: K,
- limit: Optional[int],
+ limit: int,
room_ids: Collection[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index b52723e2b8..6df2de919c 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -35,17 +35,19 @@ class PaginationConfig:
from_token: Optional[StreamToken]
to_token: Optional[StreamToken]
direction: str
- limit: Optional[int]
+ limit: int
@classmethod
async def from_request(
cls,
store: "DataStore",
request: SynapseRequest,
- raise_invalid_params: bool = True,
- default_limit: Optional[int] = None,
+ default_limit: int,
+ default_dir: str = "f",
) -> "PaginationConfig":
- direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
+ direction = parse_string(
+ request, "dir", default=default_dir, allowed_values=["f", "b"]
+ )
from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to")
@@ -67,12 +69,10 @@ class PaginationConfig:
raise SynapseError(400, "'to' parameter is invalid")
limit = parse_integer(request, "limit", default=default_limit)
+ if limit < 0:
+ raise SynapseError(400, "Limit must be 0 or above")
- if limit:
- if limit < 0:
- raise SynapseError(400, "Limit must be 0 or above")
-
- limit = min(int(limit), MAX_LIMIT)
+ limit = min(limit, MAX_LIMIT)
try:
return PaginationConfig(from_tok, to_tok, direction, limit)
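
With a default now required, limit is always an int by the time it is validated, so the check collapses to reject-negative-then-clamp. Illustrated standalone (the MAX_LIMIT value below is assumed for the sketch, not taken from Synapse):

    MAX_LIMIT = 1000  # assumed cap for the sketch; Synapse defines its own constant.

    def clamp_limit(limit: int, max_limit: int = MAX_LIMIT) -> int:
        if limit < 0:
            raise ValueError("Limit must be 0 or above")
        return min(limit, max_limit)

    assert clamp_limit(10) == 10
    assert clamp_limit(10_000) == MAX_LIMIT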
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index a90f08dd4c..7be9d5f113 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -15,7 +15,7 @@
import json
import logging
import typing
-from typing import Any, Callable, Dict, Generator, Optional
+from typing import Any, Callable, Dict, Generator, Optional, Sequence
import attr
from frozendict import frozendict
@@ -193,3 +193,15 @@ def log_failure(
# Version string with git info. Computed here once so that we don't invoke git multiple
# times.
SYNAPSE_VERSION = get_distribution_version_string("matrix-synapse", __file__)
+
+
+class ExceptionBundle(Exception):
+ # A poor stand-in for something like Python 3.11's ExceptionGroup.
+ # (A backport called `exceptiongroup` exists but seems overkill: we just want a
+ # container type here.)
+ def __init__(self, message: str, exceptions: Sequence[Exception]):
+ parts = [message]
+ for e in exceptions:
+ parts.append(str(e))
+ super().__init__("\n - ".join(parts))
+ self.exceptions = exceptions
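
Used on its own, the bundle joins its children's messages into one readable string while keeping the original exceptions reachable. A quick usage sketch, repeating a minimal copy of the class so the snippet runs standalone:

    from typing import Sequence

    class ExceptionBundle(Exception):
        # Same shape as the class added above, repeated so this snippet runs on its own.
        def __init__(self, message: str, exceptions: Sequence[Exception]):
            parts = [message]
            for e in exceptions:
                parts.append(str(e))
            super().__init__("\n - ".join(parts))
            self.exceptions = exceptions

    errors = [ValueError("bad value"), KeyError("missing key")]
    bundle = ExceptionBundle("2 failures while shutting down", errors)
    print(str(bundle))
    # 2 failures while shutting down
    #  - bad value
    #  - 'missing key'
    assert bundle.exceptions == errors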
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 27a363d7e5..4961fe9313 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -86,7 +86,7 @@ def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
ValueError if the server name could not be parsed.
"""
try:
- if server_name[-1] == "]":
+ if server_name and server_name[-1] == "]":
# ipv6 literal, hopefully
return server_name, None
@@ -123,7 +123,7 @@ def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
- if host[0] == "[":
+ if host and host[0] == "[":
if host[-1] != "]":
raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
|