diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 2a2c9e6f13..bb33345be6 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,10 +15,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
from typing import List
import jsonschema
-from canonicaljson import json
from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index b6c9085670..7d309b1bb0 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -14,13 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
+import json
import logging
import os
import sys
import tempfile
-from canonicaljson import json
-
from twisted.internet import defer, task
import synapse
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index c96e6ef62a..13d6f6a3ea 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -17,6 +17,7 @@ import logging
import logging.config
import os
import sys
+import threading
from string import Template
import yaml
@@ -25,6 +26,7 @@ from twisted.logger import (
ILogObserver,
LogBeginner,
STDLibLogObserver,
+ eventAsText,
globalLogBeginner,
)
@@ -216,8 +218,9 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
# system.
observer = STDLibLogObserver()
- def _log(event):
+ threadlocal = threading.local()
+ def _log(event):
if "log_text" in event:
if event["log_text"].startswith("DNSDatagramProtocol starting on "):
return
@@ -228,7 +231,25 @@ def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
if event["log_text"].startswith("Timing out client"):
return
- return observer(event)
+ # this is a workaround to make sure we don't get stack overflows when the
+ # logging system raises an error which is written to stderr, which is in
+ # turn redirected back into the logging system, and so on.
+ if getattr(threadlocal, "active", False):
+ # write the text of the event, if any, to the *real* stderr (which may
+ # be redirected to /dev/null, but there's not much we can do)
+ try:
+ event_text = eventAsText(event)
+ print("logging during logging: %s" % event_text, file=sys.__stderr__)
+ except Exception:
+ # gah.
+ pass
+ return
+
+ try:
+ threadlocal.active = True
+ return observer(event)
+ finally:
+ threadlocal.active = False
logBeginner.beginLoggingTo([_log], redirectStandardIO=not config.no_redirect_stdio)
if not config.no_redirect_stdio:
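The threadlocal flag added above breaks the feedback loop where an error raised inside the logging system is written to stderr, which is itself redirected back into the logging system. A standalone sketch of the same re-entrancy-guard pattern (names hypothetical; any callable observer works, and the per-thread flag means concurrent threads do not block each other):

```python
import sys
import threading

def wrap_observer(observer):
    """Reject re-entrant calls on the same thread, falling back to the
    real stderr, so an observer that itself logs cannot recurse forever."""
    threadlocal = threading.local()

    def _log(event):
        if getattr(threadlocal, "active", False):
            # already inside the observer on this thread: bail out
            print("logging during logging: %s" % (event,), file=sys.__stderr__)
            return
        threadlocal.active = True
        try:
            return observer(event)
        finally:
            threadlocal.active = False

    return _log

guarded = wrap_observer(print)
guarded({"log_text": "hello"})  # delivered once, never re-entered
```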
diff --git a/synapse/config/server.py b/synapse/config/server.py
index e85c6a0840..532b910470 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,7 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional, Set
import attr
import yaml
@@ -542,6 +542,19 @@ class ServerConfig(Config):
users_new_default_push_rules
) # type: set
+ # Whitelist of domain names that the domain portion of any given next_link parameter must match
+ next_link_domain_whitelist = config.get(
+ "next_link_domain_whitelist"
+ ) # type: Optional[List[str]]
+
+ self.next_link_domain_whitelist = None # type: Optional[Set[str]]
+ if next_link_domain_whitelist is not None:
+ if not isinstance(next_link_domain_whitelist, list):
+ raise ConfigError("'next_link_domain_whitelist' must be a list")
+
+ # Turn the list into a set to improve lookup speed.
+ self.next_link_domain_whitelist = set(next_link_domain_whitelist)
+
def has_tls_listener(self) -> bool:
return any(listener.tls for listener in self.listeners)
@@ -1014,6 +1027,24 @@ class ServerConfig(Config):
# act as if no error happened and return a fake session ID ('sid') to clients.
#
#request_token_inhibit_3pid_errors: true
+
+ # A list of domains that the domain portion of 'next_link' parameters
+ # must match.
+ #
+ # This parameter is optionally provided by clients while requesting
+ # validation of an email or phone number, and maps to a link that
+ # users will be automatically redirected to after validation
+ # succeeds. Clients can make use this parameter to aid the validation
+ # process.
+ #
+ # The whitelist is applied whether the homeserver or an
+ # identity server is handling validation.
+ #
+ # The default value is no whitelist functionality; all domains are
+ # allowed. Setting this value to an empty list will instead disallow
+ # all domains.
+ #
+ #next_link_domain_whitelist: ["matrix.org"]
"""
% locals()
)
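The parsing above distinguishes three states: key absent (no whitelist, allow everything), empty list (deny everything), and a populated list (allow only the listed domains). A minimal sketch of that tri-state handling, with a hypothetical helper name:

```python
from typing import List, Optional, Set

def parse_next_link_whitelist(config: dict) -> Optional[Set[str]]:
    raw = config.get("next_link_domain_whitelist")  # type: Optional[List[str]]
    if raw is None:
        return None  # no whitelist configured: allow all domains
    if not isinstance(raw, list):
        raise ValueError("'next_link_domain_whitelist' must be a list")
    return set(raw)  # a set gives O(1) membership checks

assert parse_next_link_whitelist({}) is None
assert parse_next_link_whitelist({"next_link_domain_whitelist": []}) == set()
assert "matrix.org" in parse_next_link_whitelist(
    {"next_link_domain_whitelist": ["matrix.org"]}
)
```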
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 552519e82c..41a726878d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -209,7 +209,7 @@ class FederationSender:
logger.debug("Sending %s to %r", event, destinations)
if destinations:
- self._send_pdu(event, destinations)
+ await self._send_pdu(event, destinations)
now = self.clock.time_msec()
ts = await self.store.get_received_ts(event.event_id)
@@ -265,7 +265,7 @@ class FederationSender:
finally:
self._is_processing = False
- def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
+ async def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
@@ -280,6 +280,13 @@ class FederationSender:
sent_pdus_destination_dist_total.inc(len(destinations))
sent_pdus_destination_dist_count.inc()
+ # track the fact that we have a PDU for these destinations,
+ # to allow us to perform catch-up later on if the remote is unreachable
+ # for a while.
+ await self.store.store_destination_rooms_entries(
+ destinations, pdu.room_id, pdu.internal_metadata.stream_ordering,
+ )
+
for destination in destinations:
self._get_per_destination_queue(destination).send_pdu(pdu)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index defc228c23..9f0852b4a2 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -325,6 +325,17 @@ class PerDestinationQueue:
self._last_device_stream_id = device_stream_id
self._last_device_list_stream_id = dev_list_id
+
+ if pending_pdus:
+ # we sent some PDUs and it was successful, so update our
+ # last_successful_stream_ordering in the destinations table.
+ final_pdu = pending_pdus[-1]
+ last_successful_stream_ordering = (
+ final_pdu.internal_metadata.stream_ordering
+ )
+ await self._store.set_destination_last_successful_stream_ordering(
+ self._destination, last_successful_stream_ordering
+ )
else:
break
except NotRetryingDestination as e:
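Together, the two hunks above (in `sender/__init__.py` and `per_destination_queue.py`) implement the bookkeeping for federation catch-up: record the highest stream ordering queued per destination and room when a PDU is sent out, and record the last stream ordering that was delivered successfully; anything in between can be replayed when an unreachable destination comes back. An in-memory sketch of that idea (the real code persists both values in the database; names here are hypothetical):

```python
from typing import Dict, Iterable, Tuple

class CatchupLedger:
    def __init__(self) -> None:
        # (destination, room_id) -> stream ordering of the newest queued PDU
        self.pending: Dict[Tuple[str, str], int] = {}
        # destination -> stream ordering of the last PDU sent successfully
        self.last_successful: Dict[str, int] = {}

    def record_pdu(self, destinations: Iterable[str], room_id: str, ordering: int) -> None:
        for dest in destinations:
            self.pending[(dest, room_id)] = ordering

    def record_success(self, destination: str, ordering: int) -> None:
        self.last_successful[destination] = ordering

    def needs_catchup(self, destination: str) -> bool:
        sent = self.last_successful.get(destination, -1)
        return any(
            ordering > sent
            for (dest, _room), ordering in self.pending.items()
            if dest == destination
        )

ledger = CatchupLedger()
ledger.record_pdu(["remote.example"], "!room:a", 10)
assert ledger.needs_catchup("remote.example")
ledger.record_success("remote.example", 10)
assert not ledger.needs_catchup("remote.example")
```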
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index b05e32f457..fdce54c5c3 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -39,10 +39,6 @@ class EventStreamHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super(EventStreamHandler, self).__init__(hs)
- self.distributor = hs.get_distributor()
- self.distributor.declare("started_user_eventstream")
- self.distributor.declare("stopped_user_eventstream")
-
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 43f2986f89..be9b0701a0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -69,7 +69,6 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
ReplicationStoreRoomOnInviteRestServlet,
)
-from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import (
@@ -80,7 +79,6 @@ from synapse.types import (
get_domain_from_id,
)
from synapse.util.async_helpers import Linearizer, concurrently_execute
-from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server
@@ -141,9 +139,6 @@ class FederationHandler(BaseHandler):
self._replication = hs.get_replication_data_handler()
self._send_events = ReplicationFederationSendEventsRestServlet.make_client(hs)
- self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
- hs
- )
self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
hs
)
@@ -704,31 +699,10 @@ class FederationHandler(BaseHandler):
logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
try:
- context = await self._handle_new_event(origin, event, state=state)
+ await self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
- if event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- # Only fire user_joined_room if the user has acutally
- # joined the room. Don't bother if the user is just
- # changing their profile info.
- newly_joined = True
-
- prev_state_ids = await context.get_prev_state_ids()
-
- prev_state_id = prev_state_ids.get((event.type, event.state_key))
- if prev_state_id:
- prev_state = await self.store.get_event(
- prev_state_id, allow_none=True
- )
- if prev_state and prev_state.membership == Membership.JOIN:
- newly_joined = False
-
- if newly_joined:
- user = UserID.from_string(event.state_key)
- await self.user_joined_room(user, room_id)
-
# For encrypted messages we check that we know about the sending device,
# if we don't then we mark the device cache for that user as stale.
if event.type == EventTypes.Encrypted:
@@ -1550,11 +1524,6 @@ class FederationHandler(BaseHandler):
event.signatures,
)
- if event.type == EventTypes.Member:
- if event.content["membership"] == Membership.JOIN:
- user = UserID.from_string(event.state_key)
- await self.user_joined_room(user, event.room_id)
-
prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
@@ -2970,7 +2939,7 @@ class FederationHandler(BaseHandler):
event, event_stream_id, max_stream_id, extra_users=extra_users
)
- await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
+ await self.pusher_pool.on_new_notifications(max_stream_id)
async def _clean_room_for_join(self, room_id: str) -> None:
"""Called to clean up any data in DB for a given room, ready for the
@@ -2984,16 +2953,6 @@ class FederationHandler(BaseHandler):
else:
await self.store.clean_room_for_join(room_id)
- async def user_joined_room(self, user: UserID, room_id: str) -> None:
- """Called when a new user has joined the room
- """
- if self.config.worker_app:
- await self._notify_user_membership_change(
- room_id=room_id, user_id=user.to_string(), change="joined"
- )
- else:
- user_joined_room(self.distributor, user, room_id)
-
async def get_room_complexity(
self, remote_room_hosts: List[str], room_id: str
) -> Optional[dict]:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index d5ddc583ad..ddb8f0712b 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -116,14 +116,13 @@ class InitialSyncHandler(BaseHandler):
now_token = self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
- pagination_config = PaginationConfig(from_token=now_token)
- presence, _ = await presence_stream.get_pagination_rows(
- user, pagination_config.get_source_config("presence"), None
+ presence, _ = await presence_stream.get_new_events(
+ user, from_key=None, include_offline=False
)
- receipt_stream = self.hs.get_event_sources().sources["receipt"]
- receipt, _ = await receipt_stream.get_pagination_rows(
- user, pagination_config.get_source_config("receipt"), None
+ joined_rooms = [r.room_id for r in room_list if r.membership == Membership.JOIN]
+ receipt = await self.store.get_linearized_receipts_for_rooms(
+ joined_rooms, to_key=int(now_token.receipt_key),
)
tags_by_room = await self.store.get_tags_for_user(user_id)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 8a7b4916cd..d1556659e3 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1145,7 +1145,7 @@ class EventCreationHandler:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
- await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
+ await self.pusher_pool.on_new_notifications(max_stream_id)
def _notify():
try:
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 34ed0e2921..ec17d3d888 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -335,20 +335,16 @@ class PaginationHandler:
user_id = requester.user.to_string()
if pagin_config.from_token:
- room_token = pagin_config.from_token.room_key
+ from_token = pagin_config.from_token
else:
- pagin_config.from_token = (
- self.hs.get_event_sources().get_current_token_for_pagination()
- )
- room_token = pagin_config.from_token.room_key
-
- room_token = RoomStreamToken.parse(room_token)
+ from_token = self.hs.get_event_sources().get_current_token_for_pagination()
- pagin_config.from_token = pagin_config.from_token.copy_and_replace(
- "room_key", str(room_token)
- )
+ if pagin_config.limit is None:
+ # This shouldn't happen as we've set a default limit before this
+ # gets called.
+ raise Exception("limit not set")
- source_config = pagin_config.get_source_config("room")
+ room_token = RoomStreamToken.parse(from_token.room_key)
with await self.pagination_lock.read(room_id):
(
@@ -358,7 +354,7 @@ class PaginationHandler:
room_id, user_id, allow_departed_users=True
)
- if source_config.direction == "b":
+ if pagin_config.direction == "b":
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
@@ -377,26 +373,35 @@ class PaginationHandler:
# case "JOIN" would have been returned.
assert member_event_id
- leave_token = await self.store.get_topological_token_for_event(
+ leave_token_str = await self.store.get_topological_token_for_event(
member_event_id
)
- if RoomStreamToken.parse(leave_token).topological < max_topo:
- source_config.from_key = str(leave_token)
+ leave_token = RoomStreamToken.parse(leave_token_str)
+ assert leave_token.topological is not None
+
+ if leave_token.topological < max_topo:
+ from_token = from_token.copy_and_replace(
+ "room_key", leave_token_str
+ )
await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
)
+ to_room_key = None
+ if pagin_config.to_token:
+ to_room_key = pagin_config.to_token.room_key
+
events, next_key = await self.store.paginate_room_events(
room_id=room_id,
- from_key=source_config.from_key,
- to_key=source_config.to_key,
- direction=source_config.direction,
- limit=source_config.limit,
+ from_key=from_token.room_key,
+ to_key=to_room_key,
+ direction=pagin_config.direction,
+ limit=pagin_config.limit,
event_filter=event_filter,
)
- next_token = pagin_config.from_token.copy_and_replace("room_key", next_key)
+ next_token = from_token.copy_and_replace("room_key", next_key)
if events:
if event_filter:
@@ -409,7 +414,7 @@ class PaginationHandler:
if not events:
return {
"chunk": [],
- "start": pagin_config.from_token.to_string(),
+ "start": from_token.to_string(),
"end": next_token.to_string(),
}
@@ -438,7 +443,7 @@ class PaginationHandler:
events, time_now, as_client_event=as_client_event
)
),
- "start": pagin_config.from_token.to_string(),
+ "start": from_token.to_string(),
"end": next_token.to_string(),
}
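The refactor above stops writing a resolved token back into `pagin_config` and instead keeps it in a local `from_token`, so the shared config object is never mutated as a side effect of pagination. A minimal sketch of that pattern, with hypothetical types:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class PaginationConfig:
    from_token: Optional[str] = None
    direction: str = "b"
    limit: Optional[int] = 10

def resolve_from_token(pagin_config: PaginationConfig, current_token: str) -> str:
    # prefer the caller-supplied token, otherwise snapshot "now";
    # crucially, do not write the result back into pagin_config
    return pagin_config.from_token or current_token

config = PaginationConfig()
assert resolve_from_token(config, "s100") == "s100"
assert config.from_token is None  # config left untouched
```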
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 91a3aec1cc..1000ac95ff 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1108,9 +1108,6 @@ class PresenceEventSource:
def get_current_key(self):
return self.store.get_current_presence_token()
- async def get_pagination_rows(self, user, pagination_config, key):
- return await self.get_new_events(user, from_key=None, include_offline=False)
-
@cached(num_args=2, cache_context=True)
async def _get_interested_in(self, user, explicit_room_id, cache_context):
"""Returns the set of users that the given user should see presence
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 2cc6c2eb68..bdd8e52edd 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -142,18 +142,3 @@ class ReceiptEventSource:
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
-
- async def get_pagination_rows(self, user, config, key):
- to_key = int(config.from_key)
-
- if config.to_key:
- from_key = int(config.to_key)
- else:
- from_key = None
-
- room_ids = await self.store.get_rooms_for_user(user.to_string())
- events = await self.store.get_linearized_receipts_for_rooms(
- room_ids, from_key=from_key, to_key=to_key
- )
-
- return (events, to_key)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 32b7e323fa..100f335b80 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -40,7 +40,7 @@ from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
-from synapse.util.distributor import user_joined_room, user_left_room
+from synapse.util.distributor import user_left_room
from ._base import BaseHandler
@@ -149,17 +149,6 @@ class RoomMemberHandler:
raise NotImplementedError()
@abc.abstractmethod
- async def _user_joined_room(self, target: UserID, room_id: str) -> None:
- """Notifies distributor on master process that the user has joined the
- room.
-
- Args:
- target
- room_id
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Notifies distributor on master process that the user has left the
room.
@@ -221,7 +210,6 @@ class RoomMemberHandler:
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
- newly_joined = False
if event.membership == Membership.JOIN:
newly_joined = True
if prev_member_event_id:
@@ -246,12 +234,7 @@ class RoomMemberHandler:
requester, event, context, extra_users=[target], ratelimit=ratelimit,
)
- if event.membership == Membership.JOIN and newly_joined:
- # Only fire user_joined_room if the user has actually joined the
- # room. Don't bother if the user is just changing their profile
- # info.
- await self._user_joined_room(target, room_id)
- elif event.membership == Membership.LEAVE:
+ if event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
@@ -726,17 +709,7 @@ class RoomMemberHandler:
(EventTypes.Member, event.state_key), None
)
- if event.membership == Membership.JOIN:
- # Only fire user_joined_room if the user has actually joined the
- # room. Don't bother if the user is just changing their profile
- # info.
- newly_joined = True
- if prev_member_event_id:
- prev_member_event = await self.store.get_event(prev_member_event_id)
- newly_joined = prev_member_event.membership != Membership.JOIN
- if newly_joined:
- await self._user_joined_room(target_user, room_id)
- elif event.membership == Membership.LEAVE:
+ if event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
@@ -1002,10 +975,9 @@ class RoomMemberHandler:
class RoomMemberMasterHandler(RoomMemberHandler):
def __init__(self, hs):
- super(RoomMemberMasterHandler, self).__init__(hs)
+ super().__init__(hs)
self.distributor = hs.get_distributor()
- self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
async def _is_remote_room_too_complex(
@@ -1085,7 +1057,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
event_id, stream_id = await self.federation_handler.do_invite_join(
remote_room_hosts, room_id, user.to_string(), content
)
- await self._user_joined_room(user, room_id)
# Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before.
@@ -1228,11 +1199,6 @@ class RoomMemberMasterHandler(RoomMemberHandler):
)
return event.event_id, stream_id
- async def _user_joined_room(self, target: UserID, room_id: str) -> None:
- """Implements RoomMemberHandler._user_joined_room
- """
- user_joined_room(self.distributor, target, room_id)
-
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 897338fd54..e7f34737c6 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -57,8 +57,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
content=content,
)
- await self._user_joined_room(user, room_id)
-
return ret["event_id"], ret["stream_id"]
async def remote_reject_invite(
@@ -81,13 +79,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
)
return ret["event_id"], ret["stream_id"]
- async def _user_joined_room(self, target: UserID, room_id: str) -> None:
- """Implements RoomMemberHandler._user_joined_room
- """
- await self._notify_change_client(
- user_id=target.to_string(), room_id=room_id, change="joined"
- )
-
async def _user_left_room(self, target: UserID, room_id: str) -> None:
"""Implements RoomMemberHandler._user_left_room
"""
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e2ddb628ff..cc47e8b62c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1310,12 +1310,11 @@ class SyncHandler:
presence_source = self.event_sources.sources["presence"]
since_token = sync_result_builder.since_token
+ presence_key = None
+ include_offline = False
if since_token and not sync_result_builder.full_state:
presence_key = since_token.presence_key
include_offline = True
- else:
- presence_key = None
- include_offline = False
presence, presence_key = await presence_source.get_new_events(
user=user,
diff --git a/synapse/logging/utils.py b/synapse/logging/utils.py
index fea774e2e5..becf66dd86 100644
--- a/synapse/logging/utils.py
+++ b/synapse/logging/utils.py
@@ -29,11 +29,11 @@ def _log_debug_as_f(f, msg, msg_args):
lineno = f.__code__.co_firstlineno
pathname = f.__code__.co_filename
- record = logging.LogRecord(
+ record = logger.makeRecord(
name=name,
level=logging.DEBUG,
- pathname=pathname,
- lineno=lineno,
+ fn=pathname,
+ lno=lineno,
msg=msg,
args=msg_args,
exc_info=None,
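Switching from constructing `logging.LogRecord` directly to `Logger.makeRecord` routes record creation through the logger, so a custom record factory installed with `logging.setLogRecordFactory()` (or a `Logger` subclass overriding `makeRecord`) is honoured; constructing the class directly bypasses both. A small standalone demonstration using only the stdlib:

```python
import logging

logging.basicConfig(level=logging.DEBUG, format="%(name)s %(message)s")
logger = logging.getLogger("example")

# goes through the logger, and thus through any installed record factory
record = logger.makeRecord(
    name=logger.name,
    level=logging.DEBUG,
    fn=__file__,
    lno=1,
    msg="hello %s",
    args=("world",),
    exc_info=None,
)
logger.handle(record)  # prints "example hello world"
```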
diff --git a/synapse/notifier.py b/synapse/notifier.py
index b7f4041306..71f2370874 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -432,8 +432,9 @@ class Notifier:
If explicit_room_id is set, that room will be polled for events only if
it is world readable or the user has joined the room.
"""
- from_token = pagination_config.from_token
- if not from_token:
+ if pagination_config.from_token:
+ from_token = pagination_config.from_token
+ else:
from_token = self.event_sources.get_current_token()
limit = pagination_config.limit
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index b7ea4438e0..28bd8ab748 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -91,7 +91,7 @@ class EmailPusher:
pass
self.timed_call = None
- def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
+ def on_new_notifications(self, max_stream_ordering):
if self.max_stream_ordering:
self.max_stream_ordering = max(
max_stream_ordering, self.max_stream_ordering
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index f21fa9b659..26706bf3e1 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -114,7 +114,7 @@ class HttpPusher:
if should_check_for_notifs:
self._start_processing()
- def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
+ def on_new_notifications(self, max_stream_ordering):
self.max_stream_ordering = max(
max_stream_ordering, self.max_stream_ordering or 0
)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3c3262a88c..fa8473bf8d 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -64,6 +64,12 @@ class PusherPool:
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
+ # Record the last stream ID that we were poked about so we can get
+ # changes since then. We set this to the current max stream ID on
+ # startup as every individual pusher will have checked for changes on
+ # startup.
+ self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
+
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
@@ -178,20 +184,27 @@ class PusherPool:
)
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
- async def on_new_notifications(self, min_stream_id, max_stream_id):
+ async def on_new_notifications(self, max_stream_id):
if not self.pushers:
# nothing to do here.
return
+ if max_stream_id < self._last_room_stream_id_seen:
+ # Nothing to do
+ return
+
+ prev_stream_id = self._last_room_stream_id_seen
+ self._last_room_stream_id_seen = max_stream_id
+
try:
users_affected = await self.store.get_push_action_users_in_range(
- min_stream_id, max_stream_id
+ prev_stream_id, max_stream_id
)
for u in users_affected:
if u in self.pushers:
for p in self.pushers[u].values():
- p.on_new_notifications(min_stream_id, max_stream_id)
+ p.on_new_notifications(max_stream_id)
except Exception:
logger.exception("Exception in pusher on_new_notifications")
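The pool now remembers the last stream ID it handled and turns each poke carrying only a maximum into a `(previous, max]` range, which also lets it drop stale pokes. A sketch of that high-watermark pattern with hypothetical names:

```python
from typing import Optional, Tuple

class NotificationWatermark:
    def __init__(self, current_max: int) -> None:
        # start at the current maximum: everything older was already handled
        self.last_seen = current_max

    def advance(self, max_stream_id: int) -> Optional[Tuple[int, int]]:
        if max_stream_id < self.last_seen:
            return None  # stale poke, nothing new to do
        prev, self.last_seen = self.last_seen, max_stream_id
        return (prev, max_stream_id)  # range to query for changes

wm = NotificationWatermark(current_max=50)
assert wm.advance(40) is None
assert wm.advance(60) == (50, 60)
```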
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 2d995ec456..ff0c67228b 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -43,7 +43,7 @@ REQUIREMENTS = [
"jsonschema>=2.5.1",
"frozendict>=1",
"unpaddedbase64>=1.1.0",
- "canonicaljson>=1.3.0",
+ "canonicaljson>=1.4.0",
# we use the type definitions added in signedjson 1.1.
"signedjson>=1.1.0",
"pynacl>=1.2.1",
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 741329ab5f..08095fdf7d 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Optional
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
-from synapse.util.distributor import user_joined_room, user_left_room
+from synapse.util.distributor import user_left_room
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -181,9 +181,9 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
Args:
room_id (str)
user_id (str)
- change (str): Either "joined" or "left"
+ change (str): "left"
"""
- assert change in ("joined", "left")
+ assert change == "left"
return {}
@@ -192,9 +192,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
user = UserID.from_string(user_id)
- if change == "joined":
- user_joined_room(self.distributor, user, room_id)
- elif change == "left":
+ if change == "left":
user_left_room(self.distributor, user, room_id)
else:
raise Exception("Unrecognized change: %r", change)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index d6ecf5b327..ccd3147dfd 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -154,7 +154,8 @@ class ReplicationDataHandler:
max_token = self.store.get_room_max_stream_ordering()
self.notifier.on_new_room_event(event, token, max_token, extra_users)
- await self.pusher_pool.on_new_notifications(token, token)
+ max_token = self.store.get_room_max_stream_ordering()
+ await self.pusher_pool.on_new_notifications(max_token)
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index e781a3bcf4..ddf8ed5e9c 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -163,6 +163,18 @@ class PushRuleRestServlet(RestServlet):
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
async def set_rule_attr(self, user_id, spec, val):
+ if spec["attr"] not in ("enabled", "actions"):
+ # For the sake of potential future expansion, we shouldn't report
+ # 404 for an unknown attribute; first check that the attribute is
+ # known, and report the request as unrecognized otherwise.
+ raise UnrecognizedRequestError()
+
+ namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+ rule_id = spec["rule_id"]
+ is_default_rule = rule_id.startswith(".")
+ if is_default_rule:
+ if namespaced_rule_id not in BASE_RULE_IDS:
+ raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,))
if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
@@ -171,9 +183,8 @@ class PushRuleRestServlet(RestServlet):
# This should *actually* take a dict, but many clients pass
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
- namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
return await self.store.set_push_rule_enabled(
- user_id, namespaced_rule_id, val
+ user_id, namespaced_rule_id, val, is_default_rule
)
elif spec["attr"] == "actions":
actions = val.get("actions")
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 3481477731..455051ac46 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -17,6 +17,11 @@
import logging
import random
from http import HTTPStatus
+from typing import TYPE_CHECKING
+from urllib.parse import urlparse
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
from synapse.api.constants import LoginType
from synapse.api.errors import (
@@ -98,6 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -446,6 +454,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
existing_user_id = await self.store.get_user_id_by_threepid("email", email)
if existing_user_id is not None:
@@ -517,6 +528,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ # Raise if the provided next_link value isn't valid
+ assert_valid_next_link(self.hs, next_link)
+
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
@@ -603,15 +617,10 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Perform a 302 redirect if next_link is set
if next_link:
- if next_link.startswith("file:///"):
- logger.warning(
- "Not redirecting to next_link as it is a local file: address"
- )
- else:
- request.setResponseCode(302)
- request.setHeader("Location", next_link)
- finish_request(request)
- return None
+ request.setResponseCode(302)
+ request.setHeader("Location", next_link)
+ finish_request(request)
+ return None
# Otherwise show the success template
html = self.config.email_add_threepid_template_success_html_content
@@ -875,6 +884,45 @@ class ThreepidDeleteRestServlet(RestServlet):
return 200, {"id_server_unbind_result": id_server_unbind_result}
+def assert_valid_next_link(hs: "HomeServer", next_link: str):
+ """
+ Raises a SynapseError if a given next_link value is invalid
+
+ next_link is valid if its scheme is http(s) and the next_link_domain_whitelist config
+ option is either unset or contains the domain of the given next_link
+
+ Args:
+ hs: The homeserver object
+ next_link: The next_link value given by the client
+
+ Raises:
+ SynapseError: If the next_link is invalid
+ """
+ valid = True
+
+ # Parse the contents of the URL
+ next_link_parsed = urlparse(next_link)
+
+ # Scheme must not point to the local drive
+ if next_link_parsed.scheme == "file":
+ valid = False
+
+ # If the domain whitelist is set, the domain must be in it
+ if (
+ valid
+ and hs.config.next_link_domain_whitelist is not None
+ and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist
+ ):
+ valid = False
+
+ if not valid:
+ raise SynapseError(
+ 400,
+ "'next_link' domain not included in whitelist, or not http(s)",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+
class WhoamiRestServlet(RestServlet):
PATTERNS = client_patterns("/account/whoami$")
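A few examples of how the checks in `assert_valid_next_link` behave, reduced to a boolean helper with a hypothetical name (the real function raises a `SynapseError` instead of returning `False`):

```python
from typing import Optional, Set
from urllib.parse import urlparse

def is_valid_next_link(next_link: str, whitelist: Optional[Set[str]]) -> bool:
    parsed = urlparse(next_link)
    if parsed.scheme == "file":
        return False  # must not point at the local drive
    if whitelist is not None and parsed.hostname not in whitelist:
        return False  # whitelist configured: domain must be in it
    return True

assert is_valid_next_link("https://matrix.org/verified", {"matrix.org"})
assert not is_valid_next_link("file:///etc/passwd", None)
assert not is_valid_next_link("https://evil.example", {"matrix.org"})
assert not is_valid_next_link("https://matrix.org/x", set())  # empty whitelist denies all
```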
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index d2826374a7..7447eeaebe 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -80,7 +80,7 @@ class MediaFilePaths:
self, server_name, file_id, width, height, content_type, method
):
top_level_type, sub_type = content_type.split("/")
- file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+ file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
"remote_thumbnail",
server_name,
@@ -92,6 +92,23 @@ class MediaFilePaths:
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
+ # Legacy path under which thumbnails were previously stored.
+ # Should be removed after some time, once most thumbnails are stored
+ # under the new path.
+ def remote_media_thumbnail_rel_legacy(
+ self, server_name, file_id, width, height, content_type
+ ):
+ top_level_type, sub_type = content_type.split("/")
+ file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
+ return os.path.join(
+ "remote_thumbnail",
+ server_name,
+ file_id[0:2],
+ file_id[2:4],
+ file_id[4:],
+ file_name,
+ )
+
def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join(
self.base_path,
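The method suffix matters because a "scale" and a "crop" thumbnail of the same dimensions and type previously mapped to the same file name and overwrote each other. Illustrated with hypothetical values:

```python
width, height, content_type = 128, 128, "image/jpeg"
top_level_type, sub_type = content_type.split("/")

legacy = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
scaled = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, "scale")
cropped = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, "crop")

assert legacy == "128-128-image-jpeg"  # both methods collided on this name
assert scaled != cropped               # the new scheme keeps them apart
```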
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9a1b7779f7..69f353d46f 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -53,7 +53,7 @@ from .media_storage import MediaStorage
from .preview_url_resource import PreviewUrlResource
from .storage_provider import StorageProviderWrapper
from .thumbnail_resource import ThumbnailResource
-from .thumbnailer import Thumbnailer
+from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource
logger = logging.getLogger(__name__)
@@ -460,13 +460,30 @@ class MediaRepository:
return t_byte_source
async def generate_local_exact_thumbnail(
- self, media_id, t_width, t_height, t_method, t_type, url_cache
- ):
+ self,
+ media_id: str,
+ t_width: int,
+ t_height: int,
+ t_method: str,
+ t_type: str,
+ url_cache: str,
+ ) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s",
+ media_id,
+ t_method,
+ t_type,
+ e,
+ )
+ return None
+
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
@@ -506,14 +523,36 @@ class MediaRepository:
return output_path
+ # Could not generate thumbnail.
+ return None
+
async def generate_remote_exact_thumbnail(
- self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
- ):
+ self,
+ server_name: str,
+ file_id: str,
+ media_id: str,
+ t_width: int,
+ t_height: int,
+ t_method: str,
+ t_type: str,
+ ) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=False)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s",
+ media_id,
+ server_name,
+ t_method,
+ t_type,
+ e,
+ )
+ return None
+
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
@@ -559,6 +598,9 @@ class MediaRepository:
return output_path
+ # Could not generate thumbnail.
+ return None
+
async def _generate_thumbnails(
self,
server_name: Optional[str],
@@ -590,7 +632,18 @@ class MediaRepository:
FileInfo(server_name, file_id, url_cache=url_cache)
)
- thumbnailer = Thumbnailer(input_path)
+ try:
+ thumbnailer = Thumbnailer(input_path)
+ except ThumbnailError as e:
+ logger.warning(
+ "Unable to generate thumbnails for remote media %s from %s using a method of %s and type of %s: %s",
+ media_id,
+ server_name,
+ media_type,
+ e,
+ )
+ return None
+
m_width = thumbnailer.width
m_height = thumbnailer.height
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 3a352b5631..5681677fc9 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -147,6 +147,20 @@ class MediaStorage:
if os.path.exists(local_path):
return FileResponder(open(local_path, "rb"))
+ # Fallback for paths without method names
+ # Should be removed in the future
+ if file_info.thumbnail and file_info.server_name:
+ legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ )
+ legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+ if os.path.exists(legacy_local_path):
+ return FileResponder(open(legacy_local_path, "rb"))
+
for provider in self.storage_providers:
res = await provider.fetch(path, file_info) # type: Any
if res:
@@ -170,6 +184,20 @@ class MediaStorage:
if os.path.exists(local_path):
return local_path
+ # Fallback for paths without method names
+ # Should be removed in the future
+ if file_info.thumbnail and file_info.server_name:
+ legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ )
+ legacy_local_path = os.path.join(self.local_media_directory, legacy_path)
+ if os.path.exists(legacy_local_path):
+ return legacy_local_path
+
dirname = os.path.dirname(local_path)
if not os.path.exists(dirname):
os.makedirs(dirname)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index a83535b97b..30421b663a 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,6 +16,7 @@
import logging
+from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
@@ -173,7 +174,7 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
- respond_404(request)
+ raise SynapseError(400, "Failed to generate thumbnail.")
async def _select_or_generate_remote_thumbnail(
self,
@@ -235,7 +236,7 @@ class ThumbnailResource(DirectServeJsonResource):
await respond_with_file(request, desired_type, file_path)
else:
logger.warning("Failed to generate thumbnail")
- respond_404(request)
+ raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
self, request, server_name, media_id, width, height, method, m_type
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index d681bf7bf0..457ad6031c 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -15,7 +15,7 @@
import logging
from io import BytesIO
-from PIL import Image as Image
+from PIL import Image
logger = logging.getLogger(__name__)
@@ -31,12 +31,22 @@ EXIF_TRANSPOSE_MAPPINGS = {
}
+class ThumbnailError(Exception):
+ """An error occurred generating a thumbnail."""
+
+
class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
def __init__(self, input_path):
- self.image = Image.open(input_path)
+ try:
+ self.image = Image.open(input_path)
+ except OSError as e:
+ # If an error occurs opening the image, no thumbnail can be generated.
+ raise ThumbnailError from e
+
self.width, self.height = self.image.size
self.transpose_method = None
try:
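The constructor change is an instance of translating a low-level failure into a domain-specific exception with `raise ... from`, which keeps the original error reachable as `__cause__`. A self-contained sketch of the pattern (plain file I/O standing in for PIL's `Image.open`):

```python
class ThumbnailError(Exception):
    """An error occurred generating a thumbnail."""

def open_image(path: str) -> bytes:
    try:
        with open(path, "rb") as f:  # stand-in for PIL's Image.open
            return f.read()
    except OSError as e:
        # translate the low-level error into a domain-specific one,
        # keeping the original chained on as __cause__ for debugging
        raise ThumbnailError("cannot open %s" % path) from e

try:
    open_image("/no/such/file.png")
except ThumbnailError as e:
    print("failed:", e, "caused by:", repr(e.__cause__))
```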
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ed8a9bffb1..79ec8f119d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -952,7 +952,7 @@ class DatabasePool:
key_names: Collection[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[str]],
+ value_values: Iterable[Iterable[Any]],
) -> None:
"""
Upsert, many times.
@@ -981,7 +981,7 @@ class DatabasePool:
key_names: Iterable[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[str]],
+ value_values: Iterable[Iterable[Any]],
) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index add4e3ea0e..306fc6947c 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore):
}
async def get_users_whose_devices_changed(
- self, from_key: str, user_ids: Iterable[str]
+ self, from_key: int, user_ids: Iterable[str]
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
@@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
The set of user_ids whose devices have changed since `from_key`
"""
- from_key = int(from_key)
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
@@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
async def get_users_whose_signatures_changed(
- self, user_id: str, from_key: str
+ self, user_id: str, from_key: int
) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since
`from_key`.
@@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
A set of user IDs with updated signatures.
"""
- from_key = int(from_key)
+
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
sql = """
SELECT DISTINCT user_ids FROM user_signature_stream
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 86557d5512..1d76c761a6 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -17,6 +17,10 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
+BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
+ "media_repository_drop_index_wo_method"
+)
+
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
@@ -32,6 +36,59 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
where_clause="url_cache IS NOT NULL",
)
+ # The following two updates add the method to the unique constraint of
+ # the thumbnail tables. This fixes an issue where thumbnails of the
+ # same resolution but different methods could overwrite one another.
+ # This can happen with custom thumbnail configs or with dynamic thumbnailing.
+ self.db_pool.updates.register_background_index_update(
+ update_name="local_media_repository_thumbnails_method_idx",
+ index_name="local_media_repository_thumbn_media_id_width_height_method_key",
+ table="local_media_repository_thumbnails",
+ columns=[
+ "media_id",
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_type",
+ "thumbnail_method",
+ ],
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_index_update(
+ update_name="remote_media_repository_thumbnails_method_idx",
+ index_name="remote_media_repository_thumbn_media_origin_id_width_height_method_key",
+ table="remote_media_cache_thumbnails",
+ columns=[
+ "media_origin",
+ "media_id",
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_type",
+ "thumbnail_method",
+ ],
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD,
+ self._drop_media_index_without_method,
+ )
+
+ async def _drop_media_index_without_method(self, progress, batch_size):
+ def f(txn):
+ txn.execute(
+ "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+ )
+ txn.execute(
+ "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+ )
+
+ await self.db_pool.runInteraction("drop_media_indices_without_method", f)
+ await self.db_pool.updates._end_background_update(
+ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+ )
+ return 1
+
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ea833829ae..d7a03cbf7d 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,6 +69,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# room_depth
# state_groups
# state_groups_state
+ # destination_rooms
# we will build a temporary table listing the events so that we don't
# have to keep shovelling the list back and forth across the
@@ -336,6 +337,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# and finally, the tables with an index on room_id (or no useful index)
for table in (
"current_state_events",
+ "destination_rooms",
"event_backward_extremities",
"event_forward_extremities",
"event_json",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 0de802a86b..9790a31998 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -13,11 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import abc
import logging
from typing import List, Tuple, Union
+from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -27,6 +27,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
@@ -540,6 +541,25 @@ class PushRuleStore(PushRulesWorkerStore):
},
)
+ # ensure we have a push_rules_enable row
+ # enabledness defaults to true
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ INSERT INTO push_rules_enable (id, user_name, rule_id, enabled)
+ VALUES (?, ?, ?, ?)
+ ON CONFLICT DO NOTHING
+ """
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = """
+ INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled)
+ VALUES (?, ?, ?, ?)
+ """
+ else:
+ raise RuntimeError("Unknown database engine")
+
+ new_enable_id = self._push_rules_enable_id_gen.get_next()
+ txn.execute(sql, (new_enable_id, user_id, rule_id, 1))
+
async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
"""
Delete a push rule. Args specify the row to be deleted and can be
@@ -552,6 +572,12 @@ class PushRuleStore(PushRulesWorkerStore):
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+ # we don't use simple_delete_one_txn because that would fail if the
+ # user did not have a push_rule_enable row.
+ self.db_pool.simple_delete_txn(
+ txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}
+ )
+
self.db_pool.simple_delete_one_txn(
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
)
@@ -570,10 +596,29 @@ class PushRuleStore(PushRulesWorkerStore):
event_stream_ordering,
)
- async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
+ async def set_push_rule_enabled(
+ self, user_id: str, rule_id: str, enabled: bool, is_default_rule: bool
+ ) -> None:
+ """
+ Sets the `enabled` state of a push rule.
+
+ Args:
+ user_id: the user ID of the user who wishes to enable/disable the rule
+ e.g. '@tina:example.org'
+ rule_id: the full rule ID of the rule to be enabled/disabled
+ e.g. 'global/override/.m.rule.roomnotif'
+ or 'global/override/myCustomRule'
+ enabled: True if the rule is to be enabled, False if it is to be
+ disabled
+ is_default_rule: True if and only if this is a server-default rule.
+ This skips the existence check (only user-created rules are
+ guaranteed to be stored in the database `push_rules` table).
+
+ Raises:
+ NotFoundError if the rule does not exist.
+ """
with await self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
-
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
@@ -582,12 +627,47 @@ class PushRuleStore(PushRulesWorkerStore):
user_id,
rule_id,
enabled,
+ is_default_rule,
)
def _set_push_rule_enabled_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ enabled,
+ is_default_rule,
):
new_id = self._push_rules_enable_id_gen.get_next()
+
+ if not is_default_rule:
+ # first check it exists; we need to lock for key share so that a
+ # transaction that deletes the push rule will conflict with this one.
+ # We also need a push_rule_enable row to exist for every push_rules
+ # row, otherwise it is possible to simultaneously delete a push rule
+ # (that has no _enable row) and enable it, resulting in a dangling
+ # _enable row. To solve this: we either need to use SERIALISABLE or
+ # ensure we always have a push_rule_enable row for every push_rule
+ # row. We chose the latter.
+ for_key_share = "FOR KEY SHARE"
+ if not isinstance(self.database_engine, PostgresEngine):
+ # FOR KEY SHARE is not available on SQLite
+ for_key_share = ""
+ sql = (
+ """
+ SELECT 1 FROM push_rules
+ WHERE user_name = ? AND rule_id = ?
+ %s
+ """
+ % for_key_share
+ )
+ txn.execute(sql, (user_id, rule_id))
+ if txn.fetchone() is None:
+ # needed to set NOT_FOUND code.
+ raise NotFoundError("Push rule does not exist.")
+
self.db_pool.simple_upsert_txn(
txn,
"push_rules_enable",
@@ -606,8 +686,30 @@ class PushRuleStore(PushRulesWorkerStore):
)
async def set_push_rule_actions(
- self, user_id, rule_id, actions, is_default_rule
+ self,
+ user_id: str,
+ rule_id: str,
+ actions: List[Union[dict, str]],
+ is_default_rule: bool,
) -> None:
+ """
+ Sets the `actions` state of a push rule.
+
+ Raises NotFoundError if the rule does not exist; its error code is
+ NOT_FOUND.
+
+ Args:
+ user_id: the user ID of the user who wishes to enable/disable the rule
+ e.g. '@tina:example.org'
+ rule_id: the full rule ID of the rule to be enabled/disabled
+ e.g. 'global/override/.m.rule.roomnotif'
+ or 'global/override/myCustomRule'
+ actions: A list of actions (each action being a dict or string),
+ e.g. ["notify", {"set_tweak": "highlight", "value": false}]
+ is_default_rule: True if and only if this is a server-default rule.
+ This skips the existence check (only user-created rules are
+ guaranteed to be stored in the database `push_rules` table).
+ """
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -629,12 +731,19 @@ class PushRuleStore(PushRulesWorkerStore):
update_stream=False,
)
else:
- self.db_pool.simple_update_one_txn(
- txn,
- "push_rules",
- {"user_name": user_id, "rule_id": rule_id},
- {"actions": actions_json},
- )
+ try:
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "push_rules",
+ {"user_name": user_id, "rule_id": rule_id},
+ {"actions": actions_json},
+ )
+ except StoreError as serr:
+ if serr.code == 404:
+ # this sets the NOT_FOUND error code
+ raise NotFoundError("Push rule does not exist")
+ else:
+ raise
self._insert_push_rules_update_txn(
txn,
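The dual SQL spelling above exists because "insert unless the row already exists" differs by engine: Postgres uses `ON CONFLICT DO NOTHING`, SQLite uses `INSERT OR IGNORE`. A runnable sketch of the SQLite branch against an in-memory database (schema simplified):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE push_rules_enable (id INTEGER, user_name TEXT, rule_id TEXT,"
    " enabled INTEGER, UNIQUE (user_name, rule_id))"
)

# SQLite spelling; on Postgres the equivalent statement drops "OR IGNORE"
# and appends "ON CONFLICT DO NOTHING" instead
sql = (
    "INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled)"
    " VALUES (?, ?, ?, ?)"
)
conn.execute(sql, (1, "@tina:example.org", "myrule", 1))
conn.execute(sql, (2, "@tina:example.org", "myrule", 0))  # ignored: row exists

row = conn.execute(
    "SELECT enabled FROM push_rules_enable WHERE user_name = ? AND rule_id = ?",
    ("@tina:example.org", "myrule"),
).fetchone()
assert row == (1,)  # the first insert won
```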
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
new file mode 100644
index 0000000000..b64926e9c9
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -0,0 +1,33 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the method to the unique key constraint of the thumbnail databases.
+ * Otherwise you can't have a scaled and a cropped thumbnail with the same
+ * resolution, which happens quite often with dynamic thumbnailing.
+ * This is the postgres specific migration modifying the table with a background
+ * migration.
+ */
+
+-- add new index that includes method to local media
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('local_media_repository_thumbnails_method_idx', '{}');
+
+-- add new index that includes method to remote media
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
+
+-- drop old index
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
+
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
new file mode 100644
index 0000000000..1d0c04b53a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
@@ -0,0 +1,44 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the thumbnail method to the unique key constraint of the thumbnail
+ * tables. Without it, a scaled and a cropped thumbnail with the same resolution
+ * cannot coexist, which happens quite often with dynamic thumbnailing.
+ * This is the SQLite-specific migration: SQLite cannot modify the unique
+ * constraint of a table without recreating the table.
+ */
+
+CREATE TABLE local_media_repository_thumbnails_new ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO local_media_repository_thumbnails_new
+ SELECT media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method, thumbnail_length
+ FROM local_media_repository_thumbnails;
+
+DROP TABLE local_media_repository_thumbnails;
+
+ALTER TABLE local_media_repository_thumbnails_new RENAME TO local_media_repository_thumbnails;
+
+CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id);
+
+CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails_new ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO remote_media_cache_thumbnails_new
+ SELECT media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_method, thumbnail_type, thumbnail_length, filesystem_id
+ FROM remote_media_cache_thumbnails;
+
+DROP TABLE remote_media_cache_thumbnails;
+
+ALTER TABLE remote_media_cache_thumbnails_new RENAME TO remote_media_cache_thumbnails;
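
Because SQLite cannot alter a UNIQUE constraint in place, the migration above uses the create-new/copy/drop/rename dance. A toy, runnable illustration of the same pattern on a cut-down table:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE thumbs (media_id TEXT, width INTEGER, UNIQUE (media_id));
    INSERT INTO thumbs VALUES ('abc', 64);

    -- widen the unique constraint to (media_id, width)
    CREATE TABLE thumbs_new (media_id TEXT, width INTEGER,
                             UNIQUE (media_id, width));
    INSERT INTO thumbs_new SELECT media_id, width FROM thumbs;
    DROP TABLE thumbs;
    ALTER TABLE thumbs_new RENAME TO thumbs;
    """
)
# Two rows with the same media_id but different widths are now allowed.
conn.execute("INSERT INTO thumbs VALUES ('abc', 128)")
print(conn.execute("SELECT * FROM thumbs").fetchall())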
diff --git a/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
new file mode 100644
index 0000000000..847aebd85e
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ Delete stuck 'enabled' bits that correspond to deleted or non-existent push rules.
+ Server-default rules are ignored, because they are not defined in the
+ `push_rules` table.
+**/
+
+DELETE FROM push_rules_enable WHERE
+ rule_id NOT LIKE 'global/%/.m.rule.%'
+ AND NOT EXISTS (
+ SELECT 1 FROM push_rules
+ WHERE push_rules.user_name = push_rules_enable.user_name
+ AND push_rules.rule_id = push_rules_enable.rule_id
+ );
diff --git a/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
new file mode 100644
index 0000000000..ebfbed7925
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
@@ -0,0 +1,42 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This schema delta enables 'catching up' remote homeservers after a
+-- connectivity problem of any kind.
+
+-- This stores, for each (destination, room) pair, the stream_ordering of the
+-- latest event for that destination.
+CREATE TABLE IF NOT EXISTS destination_rooms (
+ -- the destination in question.
+ destination TEXT NOT NULL REFERENCES destinations (destination),
+ -- the ID of the room in question
+ room_id TEXT NOT NULL REFERENCES rooms (room_id),
+ -- the stream_ordering of the event
+ stream_ordering BIGINT NOT NULL,
+ PRIMARY KEY (destination, room_id)
+ -- We don't declare a foreign key on stream_ordering here because that'd mean
+ -- we'd need to either maintain an index (expensive) or do a table scan of
+ -- destination_rooms whenever we delete an event (also potentially expensive).
+ -- In addition to that, a foreign key on stream_ordering would be redundant
+ -- as this row doesn't need to refer to a specific event; if the event gets
+ -- deleted then it doesn't affect the validity of the stream_ordering here.
+);
+
+-- This index is needed to make it so that a deletion of a room (in the rooms
+-- table) can be efficient, as otherwise a table scan would need to be performed
+-- to check that no destination_rooms rows point to the room to be deleted.
+-- Also: it makes it efficient to delete all the entries for a given room ID,
+-- such as when purging a room.
+CREATE INDEX IF NOT EXISTS destination_rooms_room_id
+ ON destination_rooms (room_id);
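
To make the catch-up flow concrete, here is a toy, self-contained version of the kind of query a sender could run against destination_rooms: find the rooms whose newest stream_ordering is ahead of what was last sent to a destination. The schema is trimmed down (no foreign keys), and the query shape is an assumption for illustration, not the exact one Synapse runs:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE destination_rooms (
        destination TEXT NOT NULL,
        room_id TEXT NOT NULL,
        stream_ordering BIGINT NOT NULL,
        PRIMARY KEY (destination, room_id)
    );
    INSERT INTO destination_rooms VALUES
        ('remote.example', '!room1:local', 10),
        ('remote.example', '!room2:local', 25);
    """
)
last_successful = 12  # destinations.last_successful_stream_ordering
rows = conn.execute(
    "SELECT room_id, stream_ordering FROM destination_rooms"
    " WHERE destination = ? AND stream_ordering > ?",
    ("remote.example", last_successful),
).fetchall()
print(rows)  # only !room2:local still needs catching up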
diff --git a/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
new file mode 100644
index 0000000000..55f5d0f732
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This delta file fixes a regression introduced by 58/12room_stats.sql: it removes the
+-- hacky populate_stats_process_rooms_2 background job and restores the functionality
+-- under the original name.
+-- See https://github.com/matrix-org/synapse/issues/8238 for details
+
+DELETE FROM background_updates WHERE update_name = 'populate_stats_process_rooms';
+UPDATE background_updates SET update_name = 'populate_stats_process_rooms'
+ WHERE update_name = 'populate_stats_process_rooms_2';
diff --git a/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
new file mode 100644
index 0000000000..a67aa5e500
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
@@ -0,0 +1,21 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This column tracks the stream_ordering of the event that was most recently
+-- successfully transmitted to the destination.
+-- A value of NULL means that we have not sent an event successfully yet
+-- (at least, not since the introduction of this column).
+ALTER TABLE destinations
+ ADD COLUMN last_successful_stream_ordering BIGINT;
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 55a250ef06..30840dbbaa 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -74,9 +74,6 @@ class StatsStore(StateDeltasStore):
"populate_stats_process_rooms", self._populate_stats_process_rooms
)
self.db_pool.updates.register_background_update_handler(
- "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2
- )
- self.db_pool.updates.register_background_update_handler(
"populate_stats_process_users", self._populate_stats_process_users
)
# we no longer need to perform clean-up, but we will give ourselves
@@ -148,31 +145,10 @@ class StatsStore(StateDeltasStore):
return len(users_to_work_on)
async def _populate_stats_process_rooms(self, progress, batch_size):
- """
- This was a background update which regenerated statistics for rooms.
-
- It has been replaced by StatsStore._populate_stats_process_rooms_2. This background
- job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure
- someone upgrading from <v1.0.0, this background task has been turned into a no-op
- so that the potentially expensive task is not run twice.
-
- Further context: https://github.com/matrix-org/synapse/pull/7977
- """
- await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms"
- )
- return 1
-
- async def _populate_stats_process_rooms_2(self, progress, batch_size):
- """
- This is a background update which regenerates statistics for rooms.
-
- It replaces StatsStore._populate_stats_process_rooms. See its docstring for the
- reasoning.
- """
+ """This is a background update which regenerates statistics for rooms."""
if not self.stats_enabled:
await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms_2"
+ "populate_stats_process_rooms"
)
return 1
@@ -189,13 +165,13 @@ class StatsStore(StateDeltasStore):
return [r for r, in txn]
rooms_to_work_on = await self.db_pool.runInteraction(
- "populate_stats_rooms_2_get_batch", _get_next_batch
+ "populate_stats_rooms_get_batch", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms_2"
+ "populate_stats_process_rooms"
)
return 1
@@ -204,9 +180,9 @@ class StatsStore(StateDeltasStore):
progress["last_room_id"] = room_id
await self.db_pool.runInteraction(
- "_populate_stats_process_rooms_2",
+ "_populate_stats_process_rooms",
self.db_pool.updates._background_update_progress_txn,
- "populate_stats_process_rooms_2",
+ "populate_stats_process_rooms",
progress,
)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index be6df8a6d1..08a13a8b47 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -79,8 +79,8 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
- from_token: Optional[Tuple[int, int]],
- to_token: Optional[Tuple[int, int]],
+ from_token: Optional[Tuple[Optional[int], int]],
+ to_token: Optional[Tuple[Optional[int], int]],
engine: BaseDatabaseEngine,
) -> str:
"""Creates an SQL expression to bound the columns by the pagination
@@ -535,13 +535,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0:
return [], end_token
- end_token = RoomStreamToken.parse(end_token)
+ parsed_end_token = RoomStreamToken.parse(end_token)
rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
- from_token=end_token,
+ from_token=parsed_end_token,
limit=limit,
)
@@ -989,8 +989,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token,
- to_token=to_token,
+ from_token=from_token.as_tuple(),
+ to_token=to_token.as_tuple() if to_token else None,
engine=self.database_engine,
)
@@ -1083,16 +1083,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`).
"""
- from_key = RoomStreamToken.parse(from_key)
+ parsed_from_key = RoomStreamToken.parse(from_key)
+ parsed_to_key = None
if to_key:
- to_key = RoomStreamToken.parse(to_key)
+ parsed_to_key = RoomStreamToken.parse(to_key)
rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
- from_key,
- to_key,
+ parsed_from_key,
+ parsed_to_key,
direction,
limit,
event_filter,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 5b31aab700..c0a958252e 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -15,13 +15,14 @@
import logging
from collections import namedtuple
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache
@@ -164,7 +165,9 @@ class TransactionStore(SQLBaseStore):
allow_none=True,
)
- if result and result["retry_last_ts"] > 0:
+ # check we have a row and retry_last_ts is neither null nor zero
+ # (retry_last_ts can't be negative)
+ if result and result["retry_last_ts"]:
return result
else:
return None
@@ -273,3 +276,98 @@ class TransactionStore(SQLBaseStore):
await self.db_pool.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn
)
+
+ async def store_destination_rooms_entries(
+ self, destinations: Iterable[str], room_id: str, stream_ordering: int,
+ ) -> None:
+ """
+ Updates or creates `destination_rooms` entries in batch for a single event.
+
+ Args:
+ destinations: list of destinations
+ room_id: the room_id of the event
+ stream_ordering: the stream_ordering of the event
+ """
+
+ return await self.db_pool.runInteraction(
+ "store_destination_rooms_entries",
+ self._store_destination_rooms_entries_txn,
+ destinations,
+ room_id,
+ stream_ordering,
+ )
+
+ def _store_destination_rooms_entries_txn(
+ self,
+ txn: LoggingTransaction,
+ destinations: Iterable[str],
+ room_id: str,
+ stream_ordering: int,
+ ) -> None:
+
+ # ensure we have a `destinations` row for this destination, as there is
+ # a foreign key constraint.
+ if isinstance(self.database_engine, PostgresEngine):
+ q = """
+ INSERT INTO destinations (destination)
+ VALUES (?)
+ ON CONFLICT DO NOTHING;
+ """
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ q = """
+ INSERT OR IGNORE INTO destinations (destination)
+ VALUES (?);
+ """
+ else:
+ raise RuntimeError("Unknown database engine")
+
+ # materialise `destinations` first: it may be a one-shot iterator, and we
+ # need to traverse it twice below.
+ destinations = list(destinations)
+
+ txn.execute_batch(q, ((destination,) for destination in destinations))
+
+ rows = [(destination, room_id) for destination in destinations]
+
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ "destination_rooms",
+ ["destination", "room_id"],
+ rows,
+ ["stream_ordering"],
+ [(stream_ordering,)] * len(rows),
+ )
+
+ async def get_destination_last_successful_stream_ordering(
+ self, destination: str
+ ) -> Optional[int]:
+ """
+ Gets the stream ordering of the PDU most-recently successfully sent
+ to the specified destination, or None if this information has not been
+ tracked yet.
+
+ Args:
+ destination: the destination to query
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ "destinations",
+ {"destination": destination},
+ "last_successful_stream_ordering",
+ allow_none=True,
+ desc="get_last_successful_stream_ordering",
+ )
+
+ async def set_destination_last_successful_stream_ordering(
+ self, destination: str, last_successful_stream_ordering: int
+ ) -> None:
+ """
+ Marks that we have successfully sent the PDUs up to and including the
+ one specified.
+
+ Args:
+ destination: the destination we have successfully sent to
+ last_successful_stream_ordering: the stream_ordering of the most
+ recent successfully-sent PDU
+ """
+ return await self.db_pool.simple_upsert(
+ "destinations",
+ keyvalues={"destination": destination},
+ values={"last_successful_stream_ordering": last_successful_stream_ordering},
+ desc="set_last_successful_stream_ordering",
+ )
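
The transaction above has to pick dialect-specific "insert if absent" SQL. A compact sketch of that branching, runnable against the SQLite branch (the Postgres string is included only to mirror the logic above):

import sqlite3

def ensure_destination_sql(is_postgres: bool) -> str:
    if is_postgres:
        return (
            "INSERT INTO destinations (destination) VALUES (?)"
            " ON CONFLICT DO NOTHING"
        )
    return "INSERT OR IGNORE INTO destinations (destination) VALUES (?)"

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE destinations (destination TEXT PRIMARY KEY)")
q = ensure_destination_sql(is_postgres=False)
# executemany plays the role of txn.execute_batch here
conn.executemany(q, [("a.example",), ("b.example",), ("a.example",)])
print(conn.execute("SELECT count(*) FROM destinations").fetchone())  # (2,)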
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index ee60e2a718..a7f2dfb850 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -19,12 +19,15 @@ import logging
import os
import re
from collections import Counter
-from typing import TextIO
+from typing import Optional, TextIO
import attr
+from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -63,7 +66,12 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
)
-def prepare_database(db_conn, database_engine, config, databases=["main", "state"]):
+def prepare_database(
+ db_conn: Connection,
+ database_engine: BaseDatabaseEngine,
+ config: Optional[HomeServerConfig],
+ databases: Collection[str] = ["main", "state"],
+):
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -73,16 +81,24 @@ def prepare_database(db_conn, database_engine, config, databases=["main", "state
Args:
db_conn:
database_engine:
- config (synapse.config.homeserver.HomeServerConfig|None):
+ config:
application config, or None if we are connecting to an existing
database which we expect to be configured already
- databases (list[str]): The name of the databases that will be used
+ databases: The name of the databases that will be used
with this physical database. Defaults to all databases.
"""
try:
cur = db_conn.cursor()
+ # sqlite does not automatically start transactions for DDL / SELECT statements,
+ # so we start one before running anything. This ensures that any upgrades
+ # are either applied completely, or not at all.
+ #
+ # (psycopg2 automatically starts a transaction as soon as we run any statements
+ # at all, so this is redundant but harmless there.)
+ cur.execute("BEGIN TRANSACTION")
+
logger.info("%r: Checking existing schema version", databases)
version_info = _get_or_create_schema_state(cur, database_engine)
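
Why the explicit BEGIN matters: with a transaction open, a failure halfway through a multi-statement upgrade rolls everything back, so the schema is never left half-migrated. A toy demonstration with sqlite3, using isolation_level=None so the driver does not manage transactions itself:

import sqlite3

conn = sqlite3.connect(":memory:", isolation_level=None)
cur = conn.cursor()
cur.execute("BEGIN TRANSACTION")
try:
    cur.execute("CREATE TABLE a (x INTEGER)")
    cur.execute("CREATE TABLE a (x INTEGER)")  # fails: table already exists
    cur.execute("COMMIT")
except sqlite3.OperationalError:
    cur.execute("ROLLBACK")

# The rollback undid the first CREATE as well: no table 'a' remains.
tables = cur.execute(
    "SELECT name FROM sqlite_master WHERE type = 'table'"
).fetchall()
print(tables)  # []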
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b7eb4f8ac9..2a66b3ad4e 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -224,6 +224,10 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ # Set of local IDs that we've processed that are larger than the current
+ # position, due to there being smaller unpersisted IDs.
+ self._finished_ids = set() # type: Set[int]
+
# We track the max position where we know everything before has been
# persisted. This is done by a) looking at the min across all instances
# and b) noting that if we have seen a run of persisted positions
@@ -348,17 +352,44 @@ class MultiWriterIdGenerator:
def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
- current poistion if possible.
+ current position if possible.
"""
with self._lock:
self._unfinished_ids.discard(next_id)
+ self._finished_ids.add(next_id)
+
+ new_cur = None
+
+ if self._unfinished_ids:
+ # If there are unfinished IDs then the new position will be the
+ # largest finished ID less than the minimum unfinished ID.
+
+ finished = set()
+
+ min_unfinished = min(self._unfinished_ids)
+ for s in self._finished_ids:
+ if s < min_unfinished:
+ if new_cur is None or new_cur < s:
+ new_cur = s
+ else:
+ finished.add(s)
+
+ # We clear these out since they're now all less than the new
+ # position.
+ self._finished_ids = finished
+ else:
+ # There are no unfinished IDs so the new position is simply the
+ # largest finished one.
+ new_cur = max(self._finished_ids)
+
+ # We clear these out since they're now all at or below the new
+ # position.
+ self._finished_ids.clear()
- # Figure out if its safe to advance the position by checking there
- # aren't any lower allocated IDs that are yet to finish.
- if all(c > next_id for c in self._unfinished_ids):
+ if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
- self._current_positions[self._instance_name] = max(curr, next_id)
+ self._current_positions[self._instance_name] = max(curr, new_cur)
self._add_persisted_position(next_id)
@@ -428,7 +459,7 @@ class MultiWriterIdGenerator:
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
- min_curr = min(self._current_positions.values())
+ min_curr = min(self._current_positions.values(), default=0)
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are
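
The new bookkeeping in _mark_id_as_finished is easier to follow in isolation: the position may only advance to the largest finished ID that sits below every still-unfinished ID, and finished IDs above that point are remembered for later. A distilled, standalone sketch (not the class itself):

def advance(position, finished, unfinished, next_id):
    unfinished.discard(next_id)
    finished.add(next_id)
    if unfinished:
        # only IDs below the smallest unfinished ID are safe to pass
        gate = min(unfinished)
        candidates = {s for s in finished if s < gate}
        finished -= candidates  # keep the rest for a later advance
        new_cur = max(candidates, default=None)
    else:
        new_cur = max(finished)
        finished.clear()
    return max(position, new_cur) if new_cur is not None else position

pos, fin, unfin = 0, set(), {1, 2, 3}
pos = advance(pos, fin, unfin, 3)  # 3 done, but 1 and 2 still pending -> 0
pos = advance(pos, fin, unfin, 1)  # 1 done, 2 pending -> advance to 1
pos = advance(pos, fin, unfin, 2)  # all done -> advance to 3
print(pos)  # 3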
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index d97dc4d101..0bdf846edf 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -14,9 +14,13 @@
# limitations under the License.
import logging
+from typing import Optional
+
+import attr
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.site import SynapseRequest
from synapse.types import StreamToken
logger = logging.getLogger(__name__)
@@ -25,38 +29,22 @@ logger = logging.getLogger(__name__)
MAX_LIMIT = 1000
-class SourcePaginationConfig:
-
- """A configuration object which stores pagination parameters for a
- specific event source."""
-
- def __init__(self, from_key=None, to_key=None, direction="f", limit=None):
- self.from_key = from_key
- self.to_key = to_key
- self.direction = "f" if direction == "f" else "b"
- self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
-
- def __repr__(self):
- return "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)" % (
- self.from_key,
- self.to_key,
- self.direction,
- self.limit,
- )
-
-
+@attr.s(slots=True)
class PaginationConfig:
-
"""A configuration object which stores pagination parameters."""
- def __init__(self, from_token=None, to_token=None, direction="f", limit=None):
- self.from_token = from_token
- self.to_token = to_token
- self.direction = "f" if direction == "f" else "b"
- self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
+ from_token = attr.ib(type=Optional[StreamToken])
+ to_token = attr.ib(type=Optional[StreamToken])
+ direction = attr.ib(type=str)
+ limit = attr.ib(type=Optional[int])
@classmethod
- def from_request(cls, request, raise_invalid_params=True, default_limit=None):
+ def from_request(
+ cls,
+ request: SynapseRequest,
+ raise_invalid_params: bool = True,
+ default_limit: Optional[int] = None,
+ ) -> "PaginationConfig":
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
from_tok = parse_string(request, "from")
@@ -78,8 +66,11 @@ class PaginationConfig:
limit = parse_integer(request, "limit", default=default_limit)
- if limit and limit < 0:
- raise SynapseError(400, "Limit must be 0 or above")
+ if limit:
+ if limit < 0:
+ raise SynapseError(400, "Limit must be 0 or above")
+
+ limit = min(int(limit), MAX_LIMIT)
try:
return PaginationConfig(from_tok, to_tok, direction, limit)
@@ -87,20 +78,10 @@ class PaginationConfig:
logger.exception("Failed to create pagination config")
raise SynapseError(400, "Invalid request.")
- def __repr__(self):
+ def __repr__(self) -> str:
return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % (
self.from_token,
self.to_token,
self.direction,
self.limit,
)
-
- def get_source_config(self, source_name):
- keyname = "%s_key" % source_name
-
- return SourcePaginationConfig(
- from_key=getattr(self.from_token, keyname),
- to_key=getattr(self.to_token, keyname) if self.to_token else None,
- direction=self.direction,
- limit=self.limit,
- )
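
The new from_request logic both rejects negative limits and clamps oversized ones to MAX_LIMIT (previously the clamping lived in the constructor). A tiny standalone version of just that validation step; the SynapseError stub here is an illustrative stand-in:

MAX_LIMIT = 1000

class SynapseError(Exception):
    def __init__(self, code: int, msg: str):
        super().__init__(f"{code}: {msg}")

def clamp_limit(limit):
    if limit:
        if limit < 0:
            raise SynapseError(400, "Limit must be 0 or above")
        limit = min(int(limit), MAX_LIMIT)
    return limit

print(clamp_limit(None))    # None (no limit requested)
print(clamp_limit(20))      # 20
print(clamp_limit(999999))  # clamped to 1000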
diff --git a/synapse/types.py b/synapse/types.py
index f7de48f148..ba45335038 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -18,7 +18,7 @@ import re
import string
import sys
from collections import namedtuple
-from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
+from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
import attr
from signedjson.key import decode_verify_key_bytes
@@ -362,22 +362,79 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
return username.decode("ascii")
-class StreamToken(
- namedtuple(
- "Token",
- (
- "room_key",
- "presence_key",
- "typing_key",
- "receipt_key",
- "account_data_key",
- "push_rules_key",
- "to_device_key",
- "device_list_key",
- "groups_key",
- ),
+@attr.s(frozen=True, slots=True)
+class RoomStreamToken:
+ """Tokens are positions between events. The token "s1" comes after event 1.
+
+ s0 s1
+ | |
+ [0] V [1] V [2]
+
+ Tokens can either be a point in the live event stream or a cursor going
+ through historic events.
+
+ When traversing the live event stream events are ordered by when they
+ arrived at the homeserver.
+
+ When traversing historic events the events are ordered by their depth in
+ the event graph "topological_ordering" and then by when they arrived at the
+ homeserver "stream_ordering".
+
+ Live tokens start with an "s" followed by the "stream_ordering" id of the
+ event it comes after. Historic tokens start with a "t" followed by the
+ "topological_ordering" id of the event it comes after, followed by "-",
+ followed by the "stream_ordering" id of the event it comes after.
+ """
+
+ topological = attr.ib(
+ type=Optional[int],
+ validator=attr.validators.optional(attr.validators.instance_of(int)),
)
-):
+ stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
+
+ @classmethod
+ def parse(cls, string: str) -> "RoomStreamToken":
+ try:
+ if string[0] == "s":
+ return cls(topological=None, stream=int(string[1:]))
+ if string[0] == "t":
+ parts = string[1:].split("-", 1)
+ return cls(topological=int(parts[0]), stream=int(parts[1]))
+ except Exception:
+ pass
+ raise SynapseError(400, "Invalid token %r" % (string,))
+
+ @classmethod
+ def parse_stream_token(cls, string: str) -> "RoomStreamToken":
+ try:
+ if string[0] == "s":
+ return cls(topological=None, stream=int(string[1:]))
+ except Exception:
+ pass
+ raise SynapseError(400, "Invalid token %r" % (string,))
+
+ def as_tuple(self) -> Tuple[Optional[int], int]:
+ return (self.topological, self.stream)
+
+ def __str__(self) -> str:
+ if self.topological is not None:
+ return "t%d-%d" % (self.topological, self.stream)
+ else:
+ return "s%d" % (self.stream,)
+
+
+@attr.s(slots=True, frozen=True)
+class StreamToken:
+ room_key = attr.ib(type=str)
+ presence_key = attr.ib(type=int)
+ typing_key = attr.ib(type=int)
+ receipt_key = attr.ib(type=int)
+ account_data_key = attr.ib(type=int)
+ push_rules_key = attr.ib(type=int)
+ to_device_key = attr.ib(type=int)
+ device_list_key = attr.ib(type=int)
+ groups_key = attr.ib(type=int)
+
_SEPARATOR = "_"
START = None # type: StreamToken
@@ -385,15 +442,15 @@ class StreamToken(
def from_string(cls, string):
try:
keys = string.split(cls._SEPARATOR)
- while len(keys) < len(cls._fields):
+ while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key
keys.append("0")
- return cls(*keys)
+ return cls(keys[0], *(int(k) for k in keys[1:]))
except Exception:
raise SynapseError(400, "Invalid Token")
def to_string(self):
- return self._SEPARATOR.join([str(k) for k in self])
+ return self._SEPARATOR.join([str(k) for k in attr.astuple(self)])
@property
def room_stream_id(self):
@@ -435,63 +492,10 @@ class StreamToken(
return self
def copy_and_replace(self, key, new_value):
- return self._replace(**{key: new_value})
-
-
-StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1)))
-
-
-class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
- """Tokens are positions between events. The token "s1" comes after event 1.
-
- s0 s1
- | |
- [0] V [1] V [2]
-
- Tokens can either be a point in the live event stream or a cursor going
- through historic events.
-
- When traversing the live event stream events are ordered by when they
- arrived at the homeserver.
-
- When traversing historic events the events are ordered by their depth in
- the event graph "topological_ordering" and then by when they arrived at the
- homeserver "stream_ordering".
-
- Live tokens start with an "s" followed by the "stream_ordering" id of the
- event it comes after. Historic tokens start with a "t" followed by the
- "topological_ordering" id of the event it comes after, followed by "-",
- followed by the "stream_ordering" id of the event it comes after.
- """
+ return attr.evolve(self, **{key: new_value})
- __slots__ = [] # type: list
-
- @classmethod
- def parse(cls, string):
- try:
- if string[0] == "s":
- return cls(topological=None, stream=int(string[1:]))
- if string[0] == "t":
- parts = string[1:].split("-", 1)
- return cls(topological=int(parts[0]), stream=int(parts[1]))
- except Exception:
- pass
- raise SynapseError(400, "Invalid token %r" % (string,))
- @classmethod
- def parse_stream_token(cls, string):
- try:
- if string[0] == "s":
- return cls(topological=None, stream=int(string[1:]))
- except Exception:
- pass
- raise SynapseError(400, "Invalid token %r" % (string,))
-
- def __str__(self):
- if self.topological is not None:
- return "t%d-%d" % (self.topological, self.stream)
- else:
- return "s%d" % (self.stream,)
+StreamToken.START = StreamToken.from_string("s0_0")
class ThirdPartyInstanceID(
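
The token grammar documented in RoomStreamToken round-trips cleanly: "s<stream>" for live tokens and "t<topological>-<stream>" for historic ones. A dependency-free sketch of the parse/serialise pair (Synapse's real class is the attrs-based one above):

from typing import Optional, Tuple

def parse_room_token(s: str) -> Tuple[Optional[int], int]:
    try:
        if s[0] == "s":
            return (None, int(s[1:]))
        if s[0] == "t":
            topo, stream = s[1:].split("-", 1)
            return (int(topo), int(stream))
    except Exception:
        pass
    raise ValueError("Invalid token %r" % (s,))

def token_to_str(topological: Optional[int], stream: int) -> str:
    if topological is not None:
        return "t%d-%d" % (topological, stream)
    return "s%d" % (stream,)

assert parse_room_token("s42") == (None, 42)
assert parse_room_token("t5-42") == (5, 42)
assert token_to_str(*parse_room_token("t5-42")) == "t5-42"
print("round-trips OK")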
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 3ad4b28fc7..b2355700ad 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,11 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import logging
import re
import attr
-from canonicaljson import json
from twisted.internet import defer, task
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index bb57e27beb..67ce9a5f39 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -17,13 +17,25 @@
import collections
import logging
from contextlib import contextmanager
-from typing import Dict, Sequence, Set, Union
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Hashable,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ TypeVar,
+ Union,
+)
import attr
from typing_extensions import ContextManager
from twisted.internet import defer
from twisted.internet.defer import CancelledError
+from twisted.internet.interfaces import IReactorTime
from twisted.python import failure
from synapse.logging.context import (
@@ -54,7 +66,7 @@ class ObservableDeferred:
__slots__ = ["_deferred", "_observers", "_result"]
- def __init__(self, deferred, consumeErrors=False):
+ def __init__(self, deferred: defer.Deferred, consumeErrors: bool = False):
object.__setattr__(self, "_deferred", deferred)
object.__setattr__(self, "_result", None)
object.__setattr__(self, "_observers", set())
@@ -111,25 +123,25 @@ class ObservableDeferred:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
- def observers(self):
+ def observers(self) -> List[defer.Deferred]:
return self._observers
- def has_called(self):
+ def has_called(self) -> bool:
return self._result is not None
- def has_succeeded(self):
+ def has_succeeded(self) -> bool:
return self._result is not None and self._result[0] is True
- def get_result(self):
+ def get_result(self) -> Any:
return self._result[1]
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
return getattr(self._deferred, name)
- def __setattr__(self, name, value):
+ def __setattr__(self, name: str, value: Any) -> None:
setattr(self._deferred, name, value)
- def __repr__(self):
+ def __repr__(self) -> str:
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
id(self),
self._result,
@@ -137,18 +149,20 @@ class ObservableDeferred:
)
-def concurrently_execute(func, args, limit):
- """Executes the function with each argument conncurrently while limiting
+def concurrently_execute(
+ func: Callable, args: Iterable[Any], limit: int
+) -> defer.Deferred:
+ """Executes the function with each argument concurrently while limiting
the number of concurrent executions.
Args:
- func (func): Function to execute, should return a deferred or coroutine.
- args (Iterable): List of arguments to pass to func, each invocation of func
+ func: Function to execute, should return a deferred or coroutine.
+ args: List of arguments to pass to func, each invocation of func
gets a single argument.
- limit (int): Maximum number of conccurent executions.
+ limit: Maximum number of concurrent executions.
Returns:
- deferred: Resolved when all function invocations have finished.
+ Deferred[list]: Resolved when all function invocations have finished.
"""
it = iter(args)
@@ -167,14 +181,17 @@ def concurrently_execute(func, args, limit):
).addErrback(unwrapFirstError)
-def yieldable_gather_results(func, iter, *args, **kwargs):
+def yieldable_gather_results(
+ func: Callable, iter: Iterable, *args: Any, **kwargs: Any
+) -> defer.Deferred:
"""Executes the function with each argument concurrently.
Args:
- func (func): Function to execute that returns a Deferred
- iter (iter): An iterable that yields items that get passed as the first
+ func: Function to execute that returns a Deferred
+ iter: An iterable that yields items that get passed as the first
argument to the function
*args: Arguments to be passed to each call to func
+ **kwargs: Keyword arguments to be passed to each call to func
Returns
Deferred[list]: Resolved when all functions have been invoked, or errors if
@@ -188,24 +205,37 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
).addErrback(unwrapFirstError)
+@attr.s(slots=True)
+class _LinearizerEntry:
+ # The number of things executing.
+ count = attr.ib(type=int)
+ # Deferreds for the things blocked from executing.
+ deferreds = attr.ib(type=collections.OrderedDict)
+
+
class Linearizer:
"""Limits concurrent access to resources based on a key. Useful to ensure
only a few things happen at a time on a given resource.
Example:
- with (yield limiter.queue("test_key")):
+ with await limiter.queue("test_key"):
# do some work.
"""
- def __init__(self, name=None, max_count=1, clock=None):
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ max_count: int = 1,
+ clock: Optional[Clock] = None,
+ ):
"""
Args:
- max_count(int): The maximum number of concurrent accesses
+ max_count: The maximum number of concurrent accesses
"""
if name is None:
- self.name = id(self)
+ self.name = id(self) # type: Union[str, int]
else:
self.name = name
@@ -216,15 +246,10 @@ class Linearizer:
self._clock = clock
self.max_count = max_count
- # key_to_defer is a map from the key to a 2 element list where
- # the first element is the number of things executing, and
- # the second element is an OrderedDict, where the keys are deferreds for the
- # things blocked from executing.
- self.key_to_defer = (
- {}
- ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
+ # key_to_defer is a map from the key to a _LinearizerEntry.
+ self.key_to_defer = {} # type: Dict[Hashable, _LinearizerEntry]
- def is_queued(self, key) -> bool:
+ def is_queued(self, key: Hashable) -> bool:
"""Checks whether there is a process queued up waiting
"""
entry = self.key_to_defer.get(key)
@@ -234,25 +259,27 @@ class Linearizer:
# There are waiting deferreds only if the OrderedDict of deferreds is
# non-empty.
- return bool(entry[1])
+ return bool(entry.deferreds)
- def queue(self, key):
+ def queue(self, key: Hashable) -> defer.Deferred:
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
# (https://twistedmatrix.com/trac/ticket/4632 meant that cancellations were not
# propagated inside inlineCallbacks until Twisted 18.7)
- entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
+ entry = self.key_to_defer.setdefault(
+ key, _LinearizerEntry(0, collections.OrderedDict())
+ )
# If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items
# When one of the things currently executing finishes it will callback
# this item so that it can continue executing.
- if entry[0] >= self.max_count:
+ if entry.count >= self.max_count:
res = self._await_lock(key)
else:
logger.debug(
"Acquired uncontended linearizer lock %r for key %r", self.name, key
)
- entry[0] += 1
+ entry.count += 1
res = defer.succeed(None)
# once we successfully get the lock, we need to return a context manager which
@@ -267,15 +294,15 @@ class Linearizer:
# We've finished executing so check if there are any things
# blocked waiting to execute and start one of them
- entry[0] -= 1
+ entry.count -= 1
- if entry[1]:
- (next_def, _) = entry[1].popitem(last=False)
+ if entry.deferreds:
+ (next_def, _) = entry.deferreds.popitem(last=False)
# we need to run the next thing in the sentinel context.
with PreserveLoggingContext():
next_def.callback(None)
- elif entry[0] == 0:
+ elif entry.count == 0:
# We were the last thing for this key: remove it from the
# map.
del self.key_to_defer[key]
@@ -283,7 +310,7 @@ class Linearizer:
res.addCallback(_ctx_manager)
return res
- def _await_lock(self, key):
+ def _await_lock(self, key: Hashable) -> defer.Deferred:
"""Helper for queue: adds a deferred to the queue
Assumes that we've already checked that we've reached the limit of the number
@@ -298,11 +325,11 @@ class Linearizer:
logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
new_defer = make_deferred_yieldable(defer.Deferred())
- entry[1][new_defer] = 1
+ entry.deferreds[new_defer] = 1
def cb(_r):
logger.debug("Acquired linearizer lock %r for key %r", self.name, key)
- entry[0] += 1
+ entry.count += 1
# if the code holding the lock completes synchronously, then it
# will recursively run the next claimant on the list. That can
@@ -331,7 +358,7 @@ class Linearizer:
)
# we just have to take ourselves back out of the queue.
- del entry[1][new_defer]
+ del entry.deferreds[new_defer]
return e
new_defer.addCallbacks(cb, eb)
@@ -419,14 +446,22 @@ class ReadWriteLock:
return _ctx_manager()
-def _cancelled_to_timed_out_error(value, timeout):
+R = TypeVar("R")
+
+
+def _cancelled_to_timed_out_error(value: R, timeout: float) -> R:
if isinstance(value, failure.Failure):
value.trap(CancelledError)
raise defer.TimeoutError(timeout, "Deferred")
return value
-def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
+def timeout_deferred(
+ deferred: defer.Deferred,
+ timeout: float,
+ reactor: IReactorTime,
+ on_timeout_cancel: Optional[Callable[[Any, float], Any]] = None,
+) -> defer.Deferred:
"""The in built twisted `Deferred.addTimeout` fails to time out deferreds
that have a canceller that throws exceptions. This method creates a new
deferred that wraps and times out the given deferred, correctly handling
@@ -437,10 +472,10 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
NOTE: Unlike `Deferred.addTimeout`, this function returns a new deferred
Args:
- deferred (Deferred)
- timeout (float): Timeout in seconds
- reactor (twisted.interfaces.IReactorTime): The twisted reactor to use
- on_timeout_cancel (callable): A callable which is called immediately
+ deferred: The Deferred to potentially timeout.
+ timeout: Timeout in seconds
+ reactor: The twisted reactor to use
+ on_timeout_cancel: A callable which is called immediately
after the deferred times out, and not if this deferred is
otherwise cancelled before the timeout.
@@ -452,7 +487,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
CancelledError Failure into a defer.TimeoutError.
Returns:
- Deferred
+ A new Deferred.
"""
new_d = defer.Deferred()
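
The Linearizer's contract is easiest to see in isolation: at most max_count callers hold the lock for a given key at once, and waiters are woken in FIFO order. Below is an asyncio analogue of that contract, purely illustrative; Synapse's implementation is Twisted-based and hands back a context manager from queue(), as the updated docstring shows:

import asyncio
from collections import defaultdict

class TinyLinearizer:
    def __init__(self, max_count: int = 1):
        # one semaphore per key, created lazily on first use
        self._semaphores = defaultdict(lambda: asyncio.Semaphore(max_count))

    def queue(self, key):
        # asyncio semaphores wake waiters in FIFO order, like the
        # OrderedDict of deferreds in the implementation above.
        return self._semaphores[key]

async def worker(lin, key, name, log):
    async with lin.queue(key):
        log.append(f"{name} start")
        await asyncio.sleep(0)
        log.append(f"{name} end")

async def main():
    lin, log = TinyLinearizer(max_count=1), []
    await asyncio.gather(*(worker(lin, "k", f"w{i}", log) for i in range(3)))
    print(log)  # each worker's start/end pair is uninterleaved

asyncio.run(main())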
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index a750261e77..f73e95393c 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -16,8 +16,6 @@ import inspect
import logging
from twisted.internet import defer
-from twisted.internet.defer import Deferred, fail, succeed
-from twisted.python import failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -29,11 +27,6 @@ def user_left_room(distributor, user, room_id):
distributor.fire("user_left_room", user=user, room_id=room_id)
-# XXX: this is no longer used. We should probably kill it.
-def user_joined_room(distributor, user, room_id):
- distributor.fire("user_joined_room", user=user, room_id=room_id)
-
-
class Distributor:
"""A central dispatch point for loosely-connected pieces of code to
register, observe, and fire signals.
@@ -81,28 +74,6 @@ class Distributor:
run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
-def maybeAwaitableDeferred(f, *args, **kw):
- """
- Invoke a function that may or may not return a Deferred or an Awaitable.
-
- This is a modified version of twisted.internet.defer.maybeDeferred.
- """
- try:
- result = f(*args, **kw)
- except Exception:
- return fail(failure.Failure(captureVars=Deferred.debug))
-
- if isinstance(result, Deferred):
- return result
- # Handle the additional case of an awaitable being returned.
- elif inspect.isawaitable(result):
- return defer.ensureDeferred(result)
- elif isinstance(result, failure.Failure):
- return fail(result)
- else:
- return succeed(result)
-
-
class Signal:
"""A Signal is a dispatch point that stores a list of callables as
observers of it.
@@ -132,22 +103,17 @@ class Signal:
Returns a Deferred that will complete when all the observers have
completed."""
- def do(observer):
- def eb(failure):
+ async def do(observer):
+ try:
+ result = observer(*args, **kwargs)
+ if inspect.isawaitable(result):
+ result = await result
+ return result
+ except Exception as e:
logger.warning(
- "%s signal observer %s failed: %r",
- self.name,
- observer,
- failure,
- exc_info=(
- failure.type,
- failure.value,
- failure.getTracebackObject(),
- ),
+ "%s signal observer %s failed: %r", self.name, observer, e,
)
- return maybeAwaitableDeferred(observer, *args, **kwargs).addErrback(eb)
-
deferreds = [run_in_background(do, o) for o in self.observers]
return make_deferred_yieldable(
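
The rewritten Signal.fire wraps each observer so that plain functions and coroutine functions are handled uniformly, and a failing observer is logged rather than breaking the rest. An asyncio sketch of that shape (Synapse's real version runs each wrapper via run_in_background on Twisted):

import asyncio
import inspect
import logging

logger = logging.getLogger("signal")

async def fire(name, observers, *args, **kwargs):
    async def do(observer):
        try:
            result = observer(*args, **kwargs)
            if inspect.isawaitable(result):
                result = await result
            return result
        except Exception as e:
            # mirror the patch: log the failure, don't propagate it
            logger.warning("%s signal observer %s failed: %r", name, observer, e)

    return await asyncio.gather(*(do(o) for o in observers))

def sync_obs(user): return f"sync saw {user}"
async def async_obs(user): return f"async saw {user}"
def bad_obs(user): raise RuntimeError("boom")

print(asyncio.run(fire("user_left_room", [sync_obs, async_obs, bad_obs], "alice")))
# -> ['sync saw alice', 'async saw alice', None]; the failure was only logged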
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 0e445e01d7..bf094c9386 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from canonicaljson import json
+import json
+
from frozendict import frozendict
@@ -66,5 +67,5 @@ def _handle_frozendict(obj):
# A JSONEncoder which is capable of encoding frozendicts without barfing.
# Additionally reduce the whitespace produced by JSON encoding.
frozendict_json_encoder = json.JSONEncoder(
- default=_handle_frozendict, separators=(",", ":"),
+ allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
)
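
The encoder change does two things: frozendicts are serialised by converting them to plain dicts via the default hook, and allow_nan=False makes NaN/Infinity (which are invalid JSON) raise instead of being emitted. A standalone sketch; frozendict is a third-party package, so a trivial stand-in is used here to keep the snippet runnable:

import json

class frozendict:
    """Minimal stand-in for the third-party frozendict type."""
    def __init__(self, d):
        self._d = dict(d)
    def to_dict(self):
        return dict(self._d)

def _handle_frozendict(obj):
    # json.JSONEncoder calls default() for types it cannot serialise natively
    if isinstance(obj, frozendict):
        return obj.to_dict()
    raise TypeError("Object of type %s is not JSON serializable" % type(obj).__name__)

frozendict_json_encoder = json.JSONEncoder(
    allow_nan=False, separators=(",", ":"), default=_handle_frozendict,
)

print(frozendict_json_encoder.encode({"a": frozendict({"b": 1})}))  # {"a":{"b":1}}
try:
    frozendict_json_encoder.encode({"x": float("nan")})
except ValueError as e:
    print("rejected NaN:", e)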