diff --git a/synapse/__init__.py b/synapse/__init__.py
index d423856d82..3cd682f9e7 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.26.0rc1"
+__version__ = "1.26.0rc2"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 29aa064e57..8ba669059a 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -18,6 +18,7 @@ from synapse.config import (
password_auth_providers,
push,
ratelimiting,
+ redis,
registration,
repository,
room_directory,
@@ -79,6 +80,7 @@ class RootConfig:
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig
+ redis: redis.RedisConfig
config_classes: List = ...
def __init__(self) -> None: ...
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index d58a83be7f..bfeceeed18 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -15,6 +15,7 @@
# limitations under the License.
import string
+from collections import Counter
from typing import Iterable, Optional, Tuple, Type
import attr
@@ -43,6 +44,16 @@ class OIDCConfig(Config):
except DependencyException as e:
raise ConfigError(e.message) from e
+ # check we don't have any duplicate idp_ids now. (The SSO handler will also
+ # check for duplicates when the REST listeners get registered, but that happens
+ # after synapse has forked so doesn't give nice errors.)
+ c = Counter([i.idp_id for i in self.oidc_providers])
+ for idp_id, count in c.items():
+ if count > 1:
+ raise ConfigError(
+ "Multiple OIDC providers have the idp_id %r." % idp_id
+ )
+
public_baseurl = self.public_baseurl
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 302b2f69bc..d330ae5dbc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,6 +18,7 @@ import copy
import itertools
import logging
from typing import (
+ TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -26,7 +27,6 @@ from typing import (
List,
Mapping,
Optional,
- Sequence,
Tuple,
TypeVar,
Union,
@@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):
class FederationClient(FederationBase):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.pdu_destination_tried = {}
+ self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
@@ -116,33 +119,32 @@ class FederationClient(FederationBase):
self.pdu_destination_tried[event_id] = destination_dict
@log_function
- def make_query(
+ async def make_query(
self,
- destination,
- query_type,
- args,
- retry_on_dns_fail=False,
- ignore_backoff=False,
- ):
+ destination: str,
+ query_type: str,
+ args: dict,
+ retry_on_dns_fail: bool = False,
+ ignore_backoff: bool = False,
+ ) -> JsonDict:
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
Args:
- destination (str): Domain name of the remote homeserver
- query_type (str): Category of the query type; should match the
+ destination: Domain name of the remote homeserver
+ query_type: Category of the query type; should match the
handler name used in register_query_handler().
- args (dict): Mapping of strings to strings containing the details
+ args: Mapping of strings to strings containing the details
of the query request.
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
- a Awaitable which will eventually yield a JSON object from the
- response
+ The JSON object from the response
"""
sent_queries_counter.labels(query_type).inc()
- return self.transport_layer.make_query(
+ return await self.transport_layer.make_query(
destination,
query_type,
args,
@@ -151,42 +153,52 @@ class FederationClient(FederationBase):
)
@log_function
- def query_client_keys(self, destination, content, timeout):
+ async def query_client_keys(
+ self, destination: str, content: JsonDict, timeout: int
+ ) -> JsonDict:
"""Query device keys for a device hosted on a remote server.
Args:
- destination (str): Domain name of the remote homeserver
- content (dict): The query content.
+ destination: Domain name of the remote homeserver
+ content: The query content.
Returns:
- an Awaitable which will eventually yield a JSON object from the
- response
+ The JSON object from the response
"""
sent_queries_counter.labels("client_device_keys").inc()
- return self.transport_layer.query_client_keys(destination, content, timeout)
+ return await self.transport_layer.query_client_keys(
+ destination, content, timeout
+ )
@log_function
- def query_user_devices(self, destination, user_id, timeout=30000):
+ async def query_user_devices(
+ self, destination: str, user_id: str, timeout: int = 30000
+ ) -> JsonDict:
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.labels("user_devices").inc()
- return self.transport_layer.query_user_devices(destination, user_id, timeout)
+ return await self.transport_layer.query_user_devices(
+ destination, user_id, timeout
+ )
@log_function
- def claim_client_keys(self, destination, content, timeout):
+ async def claim_client_keys(
+ self, destination: str, content: JsonDict, timeout: int
+ ) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
Args:
- destination (str): Domain name of the remote homeserver
- content (dict): The query content.
+ destination: Domain name of the remote homeserver
+ content: The query content.
Returns:
- an Awaitable which will eventually yield a JSON object from the
- response
+ The JSON object from the response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
- return self.transport_layer.claim_client_keys(destination, content, timeout)
+ return await self.transport_layer.claim_client_keys(
+ destination, content, timeout
+ )
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
@@ -195,10 +207,10 @@ class FederationClient(FederationBase):
given destination server.
Args:
- dest (str): The remote homeserver to ask.
- room_id (str): The room_id to backfill.
- limit (int): The maximum number of events to return.
- extremities (list): our current backwards extremities, to backfill from
+ dest: The remote homeserver to ask.
+ room_id: The room_id to backfill.
+ limit: The maximum number of events to return.
+ extremities: our current backwards extremities, to backfill from
"""
logger.debug("backfill extrem=%s", extremities)
@@ -370,7 +382,7 @@ class FederationClient(FederationBase):
for events that have failed their checks
Returns:
- Deferred : A list of PDUs that have valid signatures and hashes.
+ A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@@ -418,7 +430,9 @@ class FederationClient(FederationBase):
else:
return [p for p in valid_pdus if p]
- async def get_event_auth(self, destination, room_id, event_id):
+ async def get_event_auth(
+ self, destination: str, room_id: str, event_id: str
+ ) -> List[EventBase]:
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = await self.store.get_room_version(room_id)
@@ -700,18 +714,16 @@ class FederationClient(FederationBase):
return await self._try_destination_list("send_join", destinations, send_request)
- async def _do_send_join(self, destination: str, pdu: EventBase):
+ async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec()
try:
- content = await self.transport_layer.send_join_v2(
+ return await self.transport_layer.send_join_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
-
- return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -769,7 +781,7 @@ class FederationClient(FederationBase):
time_now = self._clock.time_msec()
try:
- content = await self.transport_layer.send_invite_v2(
+ return await self.transport_layer.send_invite_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
@@ -779,7 +791,6 @@ class FederationClient(FederationBase):
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
},
)
- return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -842,18 +853,16 @@ class FederationClient(FederationBase):
"send_leave", destinations, send_request
)
- async def _do_send_leave(self, destination, pdu):
+ async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec()
try:
- content = await self.transport_layer.send_leave_v2(
+ return await self.transport_layer.send_leave_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
-
- return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -879,7 +888,7 @@ class FederationClient(FederationBase):
# content.
return resp[1]
- def get_public_rooms(
+ async def get_public_rooms(
self,
remote_server: str,
limit: Optional[int] = None,
@@ -887,7 +896,7 @@ class FederationClient(FederationBase):
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
- ):
+ ) -> JsonDict:
"""Get the list of public rooms from a remote homeserver
Args:
@@ -901,8 +910,7 @@ class FederationClient(FederationBase):
party instance
Returns:
- Awaitable[Dict[str, Any]]: The response from the remote server, or None if
- `remote_server` is the same as the local server_name
+ The response from the remote server.
Raises:
HttpResponseException: There was an exception returned from the remote server
@@ -910,7 +918,7 @@ class FederationClient(FederationBase):
requests over federation
"""
- return self.transport_layer.get_public_rooms(
+ return await self.transport_layer.get_public_rooms(
remote_server,
limit,
since_token,
@@ -923,7 +931,7 @@ class FederationClient(FederationBase):
self,
destination: str,
room_id: str,
- earliest_events_ids: Sequence[str],
+ earliest_events_ids: Iterable[str],
latest_events: Iterable[EventBase],
limit: int,
min_depth: int,
@@ -974,7 +982,9 @@ class FederationClient(FederationBase):
return signed_events
- async def forward_third_party_invite(self, destinations, room_id, event_dict):
+ async def forward_third_party_invite(
+ self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
+ ) -> None:
for destination in destinations:
if destination == self.server_name:
continue
@@ -983,7 +993,7 @@ class FederationClient(FederationBase):
await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
- return None
+ return
except CodeMessageException:
raise
except Exception as e:
@@ -995,7 +1005,7 @@ class FederationClient(FederationBase):
async def get_room_complexity(
self, destination: str, room_id: str
- ) -> Optional[dict]:
+ ) -> Optional[JsonDict]:
"""
Fetch the complexity of a remote room from another server.
@@ -1008,10 +1018,9 @@ class FederationClient(FederationBase):
could not fetch the complexity.
"""
try:
- complexity = await self.transport_layer.get_room_complexity(
+ return await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id
)
- return complexity
except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us.
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 604cfd1935..643b26ae6d 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -142,6 +142,8 @@ class FederationSender:
self._wake_destinations_needing_catchup,
)
+ self._external_cache = hs.get_external_cache()
+
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
@@ -197,22 +199,40 @@ class FederationSender:
if not event.internal_metadata.should_proactively_send():
return
- try:
- # Get the state from before the event.
- # We need to make sure that this is the state from before
- # the event and not from after it.
- # Otherwise if the last member on a server in a room is
- # banned then it won't receive the event because it won't
- # be in the room after the ban.
- destinations = await self.state.get_hosts_in_room_at_events(
- event.room_id, event_ids=event.prev_event_ids()
- )
- except Exception:
- logger.exception(
- "Failed to calculate hosts in room for event: %s",
- event.event_id,
+ destinations = None # type: Optional[Set[str]]
+ if not event.prev_event_ids():
+ # If there are no prev event IDs then the state is empty
+ # and so no remote servers in the room
+ destinations = set()
+ else:
+ # We check the external cache for the destinations, which is
+ # stored per state group.
+
+ sg = await self._external_cache.get(
+ "event_to_prev_state_group", event.event_id
)
- return
+ if sg:
+ destinations = await self._external_cache.get(
+ "get_joined_hosts", str(sg)
+ )
+
+ if destinations is None:
+ try:
+ # Get the state from before the event.
+ # We need to make sure that this is the state from before
+ # the event and not from after it.
+ # Otherwise if the last member on a server in a room is
+ # banned then it won't receive the event because it won't
+ # be in the room after the ban.
+ destinations = await self.state.get_hosts_in_room_at_events(
+ event.room_id, event_ids=event.prev_event_ids()
+ )
+ except Exception:
+ logger.exception(
+ "Failed to calculate hosts in room for event: %s",
+ event.event_id,
+ )
+ return
destinations = {
d
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index fd8de8696d..b6dc7f99b6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
+ # If we are going to send this event over federation we precaclculate
+ # the joined hosts.
+ if event.internal_metadata.get_send_on_behalf_of():
+ await self.event_creation_handler.cache_joined_hosts_for_event(event)
+
return context
async def _check_for_soft_fail(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 97c4b1f262..bba53ae723 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -432,6 +432,8 @@ class EventCreationHandler:
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._external_cache = hs.get_external_cache()
+
async def create_event(
self,
requester: Requester,
@@ -939,6 +941,8 @@ class EventCreationHandler:
await self.action_generator.handle_push_actions_for_event(event, context)
+ await self.cache_joined_hosts_for_event(event)
+
try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -978,6 +982,44 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
+ async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
+ """Precalculate the joined hosts at the event, when using Redis, so that
+ external federation senders don't have to recalculate it themselves.
+ """
+
+ if not self._external_cache.is_enabled():
+ return
+
+ # We actually store two mappings, event ID -> prev state group,
+ # state group -> joined hosts, which is much more space efficient
+ # than event ID -> joined hosts.
+ #
+ # Note: We have to cache event ID -> prev state group, as we don't
+ # store that in the DB.
+ #
+ # Note: We always set the state group -> joined hosts cache, even if
+ # we already set it, so that the expiry time is reset.
+
+ state_entry = await self.state.resolve_state_groups_for_events(
+ event.room_id, event_ids=event.prev_event_ids()
+ )
+
+ if state_entry.state_group:
+ joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+
+ await self._external_cache.set(
+ "event_to_prev_state_group",
+ event.event_id,
+ state_entry.state_group,
+ expiry_ms=60 * 60 * 1000,
+ )
+ await self._external_cache.set(
+ "get_joined_hosts",
+ str(state_entry.state_group),
+ list(joined_hosts),
+ expiry_ms=60 * 60 * 1000,
+ )
+
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
) -> None:
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 7e50341d74..04c2c1482c 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -17,7 +17,7 @@ import logging
import re
from typing import TYPE_CHECKING, Dict, Iterable, Optional
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.types import StateMap
@@ -63,7 +63,7 @@ async def calculate_room_name(
m_room_name = await store.get_event(
room_state_ids[(EventTypes.Name, "")], allow_none=True
)
- if m_room_name and m_room_name.content and m_room_name.content["name"]:
+ if m_room_name and m_room_name.content and m_room_name.content.get("name"):
return m_room_name.content["name"]
# does it have a canonical alias?
@@ -74,15 +74,11 @@ async def calculate_room_name(
if (
canon_alias
and canon_alias.content
- and canon_alias.content["alias"]
+ and canon_alias.content.get("alias")
and _looks_like_an_alias(canon_alias.content["alias"])
):
return canon_alias.content["alias"]
- # at this point we're going to need to search the state by all state keys
- # for an event type, so rearrange the data structure
- room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
-
if not fallback_to_members:
return None
@@ -94,7 +90,7 @@ async def calculate_room_name(
if (
my_member_event is not None
- and my_member_event.content["membership"] == "invite"
+ and my_member_event.content.get("membership") == Membership.INVITE
):
if (EventTypes.Member, my_member_event.sender) in room_state_ids:
inviter_member_event = await store.get_event(
@@ -111,6 +107,10 @@ async def calculate_room_name(
else:
return "Room Invite"
+ # at this point we're going to need to search the state by all state keys
+ # for an event type, so rearrange the data structure
+ room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
+
# we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user.
if EventTypes.Member in room_state_bytype_ids:
@@ -120,8 +120,8 @@ async def calculate_room_name(
all_members = [
ev
for ev in member_events.values()
- if ev.content["membership"] == "join"
- or ev.content["membership"] == "invite"
+ if ev.content.get("membership") == Membership.JOIN
+ or ev.content.get("membership") == Membership.INVITE
]
# Sort the member events oldest-first so the we name people in the
# order the joined (it should at least be deterministic rather than
@@ -194,11 +194,7 @@ def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
def name_from_member_event(member_event: EventBase) -> str:
- if (
- member_event.content
- and "displayname" in member_event.content
- and member_event.content["displayname"]
- ):
+ if member_event.content and member_event.content.get("displayname"):
return member_event.content["displayname"]
return member_event.state_key
diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py
new file mode 100644
index 0000000000..34fa3ff5b3
--- /dev/null
+++ b/synapse/replication/tcp/external_cache.py
@@ -0,0 +1,105 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+from prometheus_client import Counter
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import json_decoder, json_encoder
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+set_counter = Counter(
+ "synapse_external_cache_set",
+ "Number of times we set a cache",
+ labelnames=["cache_name"],
+)
+
+get_counter = Counter(
+ "synapse_external_cache_get",
+ "Number of times we get a cache",
+ labelnames=["cache_name", "hit"],
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+class ExternalCache:
+ """A cache backed by an external Redis. Does nothing if no Redis is
+ configured.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ self._redis_connection = hs.get_outbound_redis_connection()
+
+ def _get_redis_key(self, cache_name: str, key: str) -> str:
+ return "cache_v1:%s:%s" % (cache_name, key)
+
+ def is_enabled(self) -> bool:
+ """Whether the external cache is used or not.
+
+ It's safe to use the cache when this returns false, the methods will
+ just no-op, but the function is useful to avoid doing unnecessary work.
+ """
+ return self._redis_connection is not None
+
+ async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
+ """Add the key/value to the named cache, with the expiry time given.
+ """
+
+ if self._redis_connection is None:
+ return
+
+ set_counter.labels(cache_name).inc()
+
+ # txredisapi requires the value to be string, bytes or numbers, so we
+ # encode stuff in JSON.
+ encoded_value = json_encoder.encode(value)
+
+ logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
+
+ return await make_deferred_yieldable(
+ self._redis_connection.set(
+ self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
+ )
+ )
+
+ async def get(self, cache_name: str, key: str) -> Optional[Any]:
+ """Look up a key/value in the named cache.
+ """
+
+ if self._redis_connection is None:
+ return None
+
+ result = await make_deferred_yieldable(
+ self._redis_connection.get(self._get_redis_key(cache_name, key))
+ )
+
+ logger.debug("Got cache result %s %s: %r", cache_name, key, result)
+
+ get_counter.labels(cache_name, result is not None).inc()
+
+ if not result:
+ return None
+
+ # For some reason the integers get magically converted back to integers
+ if isinstance(result, int):
+ return result
+
+ return json_decoder.decode(result)
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index 317796d5e0..8ea8dcd587 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import (
+ TYPE_CHECKING,
Any,
Awaitable,
Dict,
@@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
TypingStream,
)
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -88,7 +92,7 @@ class ReplicationCommandHandler:
back out to connections.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
@@ -282,13 +286,6 @@ class ReplicationCommandHandler:
if hs.config.redis.redis_enabled:
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
- lazyConnection,
- )
-
- logger.info(
- "Connecting to redis (host=%r port=%r)",
- hs.config.redis_host,
- hs.config.redis_port,
)
# First let's ensure that we have a ReplicationStreamer started.
@@ -299,13 +296,7 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
- outbound_redis_connection = lazyConnection(
- reactor=hs.get_reactor(),
- host=hs.config.redis_host,
- port=hs.config.redis_port,
- password=hs.config.redis.redis_password,
- reconnect=True,
- )
+ outbound_redis_connection = hs.get_outbound_redis_connection()
# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index bc6ba709a7..fdd087683b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Type, cast
import txredisapi
@@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext,
run_as_background_process,
+ wrap_as_background_process,
)
from synapse.replication.tcp.commands import (
Command,
@@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
immediately after initialisation.
Attributes:
- handler: The command handler to handle incoming commands.
- stream_name: The *redis* stream name to subscribe to and publish from
- (not anything to do with Synapse replication streams).
- outbound_redis_connection: The connection to redis to use to send
+ synapse_handler: The command handler to handle incoming commands.
+ synapse_stream_name: The *redis* stream name to subscribe to and publish
+ from (not anything to do with Synapse replication streams).
+ synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""
- handler = None # type: ReplicationCommandHandler
- stream_name = None # type: str
- outbound_redis_connection = None # type: txredisapi.RedisProtocol
+ synapse_handler = None # type: ReplicationCommandHandler
+ synapse_stream_name = None # type: str
+ synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
- logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
- await make_deferred_yieldable(self.subscribe(self.stream_name))
+ logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
+ await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
- self.handler.new_connection(self)
+ self.synapse_handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
# We send out our positions when there is a new connection in case the
# other side missed updates. We do this for Redis connections as the
# otherside won't know we've connected and so won't issue a REPLICATE.
- self.handler.send_positions_to_connection(self)
+ self.synapse_handler.send_positions_to_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
@@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
cmd: received command
"""
- cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
+ cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
return
@@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
- self.handler.lost_connection(self)
+ self.synapse_handler.lost_connection(self)
# mark the logging context as finished
self._logging_context.__exit__(None, None, None)
@@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
await make_deferred_yieldable(
- self.outbound_redis_connection.publish(self.stream_name, encoded_string)
+ self.synapse_outbound_redis_connection.publish(
+ self.synapse_stream_name, encoded_string
+ )
+ )
+
+
+class SynapseRedisFactory(txredisapi.RedisFactory):
+ """A subclass of RedisFactory that periodically sends pings to ensure that
+ we detect dead connections.
+ """
+
+ def __init__(
+ self,
+ hs: "HomeServer",
+ uuid: str,
+ dbid: Optional[int],
+ poolsize: int,
+ isLazy: bool = False,
+ handler: Type = txredisapi.ConnectionHandler,
+ charset: str = "utf-8",
+ password: Optional[str] = None,
+ replyTimeout: int = 30,
+ convertNumbers: Optional[int] = True,
+ ):
+ super().__init__(
+ uuid=uuid,
+ dbid=dbid,
+ poolsize=poolsize,
+ isLazy=isLazy,
+ handler=handler,
+ charset=charset,
+ password=password,
+ replyTimeout=replyTimeout,
+ convertNumbers=convertNumbers,
)
+ hs.get_clock().looping_call(self._send_ping, 30 * 1000)
+
+ @wrap_as_background_process("redis_ping")
+ async def _send_ping(self):
+ for connection in self.pool:
+ try:
+ await make_deferred_yieldable(connection.ping())
+ except Exception:
+ logger.warning("Failed to send ping to a redis connection")
-class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
+
+class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.
@@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
):
- super().__init__()
-
- # This sets the password on the RedisFactory base class (as
- # SubscriberFactory constructor doesn't pass it through).
- self.password = hs.config.redis.redis_password
+ super().__init__(
+ hs,
+ uuid="subscriber",
+ dbid=None,
+ poolsize=1,
+ replyTimeout=30,
+ password=hs.config.redis.redis_password,
+ )
- self.handler = hs.get_tcp_replication()
- self.stream_name = hs.hostname
+ self.synapse_handler = hs.get_tcp_replication()
+ self.synapse_stream_name = hs.hostname
- self.outbound_redis_connection = outbound_redis_connection
+ self.synapse_outbound_redis_connection = outbound_redis_connection
def buildProtocol(self, addr):
- p = super().buildProtocol(addr) # type: RedisSubscriber
+ p = super().buildProtocol(addr)
+ p = cast(RedisSubscriber, p)
# We do this here rather than add to the constructor of `RedisSubcriber`
# as to do so would involve overriding `buildProtocol` entirely, however
# the base method does some other things than just instantiating the
# protocol.
- p.handler = self.handler
- p.outbound_redis_connection = self.outbound_redis_connection
- p.stream_name = self.stream_name
- p.password = self.password
+ p.synapse_handler = self.synapse_handler
+ p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
+ p.synapse_stream_name = self.synapse_stream_name
return p
def lazyConnection(
- reactor,
+ hs: "HomeServer",
host: str = "localhost",
port: int = 6379,
dbid: Optional[int] = None,
reconnect: bool = True,
- charset: str = "utf-8",
password: Optional[str] = None,
- connectTimeout: Optional[int] = None,
- replyTimeout: Optional[int] = None,
- convertNumbers: bool = True,
+ replyTimeout: int = 30,
) -> txredisapi.RedisProtocol:
- """Equivalent to `txredisapi.lazyConnection`, except allows specifying a
- reactor.
+ """Creates a connection to Redis that is lazily set up and reconnects if the
+ connections is lost.
"""
- isLazy = True
- poolsize = 1
-
uuid = "%s:%d" % (host, port)
- factory = txredisapi.RedisFactory(
- uuid,
- dbid,
- poolsize,
- isLazy,
- txredisapi.ConnectionHandler,
- charset,
- password,
- replyTimeout,
- convertNumbers,
+ factory = SynapseRedisFactory(
+ hs,
+ uuid=uuid,
+ dbid=dbid,
+ poolsize=1,
+ isLazy=True,
+ handler=txredisapi.ConnectionHandler,
+ password=password,
+ replyTimeout=replyTimeout,
)
factory.continueTrying = reconnect
- for x in range(poolsize):
- reactor.connectTCP(host, port, factory, connectTimeout)
+
+ reactor = hs.get_reactor()
+ reactor.connectTCP(host, port, factory, 30)
return factory.handler
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6f7dc06503..57e0a10837 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
+# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
+
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,6 +38,7 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
+ ForwardExtremitiesRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
MakeRoomAdminRestServlet,
@@ -51,6 +54,7 @@ from synapse.rest.admin.users import (
PushersRestServlet,
ResetPasswordRestServlet,
SearchUsersRestServlet,
+ ShadowBanRestServlet,
UserAdminServlet,
UserMediaRestServlet,
UserMembershipRestServlet,
@@ -230,6 +234,8 @@ def register_servlets(hs, http_server):
EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
MakeRoomAdminRestServlet(hs).register(http_server)
+ ShadowBanRestServlet(hs).register(http_server)
+ ForwardExtremitiesRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index ab7cc9102a..da1499cab3 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -499,3 +499,60 @@ class MakeRoomAdminRestServlet(RestServlet):
)
return 200, {}
+
+
+class ForwardExtremitiesRestServlet(RestServlet):
+ """Allows a server admin to get or clear forward extremities.
+
+ Clearing does not require restarting the server.
+
+ Clear forward extremities:
+ DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+
+ Get forward_extremities:
+ GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.store = hs.get_datastore()
+
+ async def resolve_room_id(self, room_identifier: str) -> str:
+ """Resolve to a room ID, if necessary."""
+ if RoomID.is_valid(room_identifier):
+ resolved_room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+ resolved_room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+ if not resolved_room_id:
+ raise SynapseError(
+ 400, "Unknown room ID or room alias %s" % room_identifier
+ )
+ return resolved_room_id
+
+ async def on_DELETE(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id = await self.resolve_room_id(room_identifier)
+
+ deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
+ return 200, {"deleted": deleted_count}
+
+ async def on_GET(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id = await self.resolve_room_id(room_identifier)
+
+ extremities = await self.store.get_forward_extremities_for_room(room_id)
+ return 200, {"count": len(extremities), "results": extremities}
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f39e3d6d5c..68c3c64a0d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet):
The parameter `deactivated` can be used to include deactivated users.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter limit must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
user_id = parse_string(request, "user_id", default=None)
name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
@@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet):
start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
- if len(users) >= limit:
+ if (start + limit) < total:
ret["next_token"] = str(start + len(users))
return 200, ret
@@ -875,3 +890,39 @@ class UserTokenRestServlet(RestServlet):
)
return 200, {"access_token": token}
+
+
+class ShadowBanRestServlet(RestServlet):
+ """An admin API for shadow-banning a user.
+
+ A shadow-banned users receives successful responses to their client-server
+ API requests, but the events are not propagated into rooms.
+
+ Shadow-banning a user should be used as a tool of last resort and may lead
+ to confusing or broken behaviour for the client.
+
+ Example:
+
+ POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
+ {}
+
+ 200 OK
+ {}
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be shadow-banned")
+
+ await self.store.set_shadow_banned(UserID.from_string(user_id), True)
+
+ return 200, {}
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 31a41e4a27..f71a03a12d 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -300,6 +300,7 @@ class FileInfo:
thumbnail_height (int)
thumbnail_method (str)
thumbnail_type (str): Content type of thumbnail, e.g. image/png
+ thumbnail_length (int): The size of the media file, in bytes.
"""
def __init__(
@@ -312,6 +313,7 @@ class FileInfo:
thumbnail_height=None,
thumbnail_method=None,
thumbnail_type=None,
+ thumbnail_length=None,
):
self.server_name = server_name
self.file_id = file_id
@@ -321,6 +323,7 @@ class FileInfo:
self.thumbnail_height = thumbnail_height
self.thumbnail_method = thumbnail_method
self.thumbnail_type = thumbnail_type
+ self.thumbnail_length = thumbnail_length
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a632099167..bf3be653aa 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -386,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Check whether the URL should be downloaded as oEmbed content instead.
- Params:
+ Args:
url: The URL to check.
Returns:
@@ -403,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Request content from an oEmbed endpoint.
- Params:
+ Args:
endpoint: The oEmbed API endpoint.
url: The URL to pass to the API.
@@ -692,27 +692,51 @@ class PreviewUrlResource(DirectServeJsonResource):
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
+ """
+ Calculate metadata for an HTML document.
+
+ This uses lxml to parse the HTML document into the OG response. If errors
+ occur during processing of the document, an empty response is returned.
+
+ Args:
+ body: The HTML document, as bytes.
+ media_url: The URI used to download the body.
+ request_encoding: The character encoding of the body, as a string.
+
+ Returns:
+ The OG response as a dictionary.
+ """
# If there's no body, nothing useful is going to be found.
if not body:
return {}
from lxml import etree
+ # Create an HTML parser. If this fails, log and return no metadata.
try:
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body, parser)
- og = _calc_og(tree, media_uri)
+ except LookupError:
+ # blindly consider the encoding as utf-8.
+ parser = etree.HTMLParser(recover=True, encoding="utf-8")
+ except Exception as e:
+ logger.warning("Unable to create HTML parser: %s" % (e,))
+ return {}
+
+ def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
+ # Attempt to parse the body. If this fails, log and return no metadata.
+ tree = etree.fromstring(body_attempt, parser)
+ return _calc_og(tree, media_uri)
+
+ # Attempt to parse the body. If this fails, log and return no metadata.
+ try:
+ return _attempt_calc_og(body)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
- parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
- og = _calc_og(tree, media_uri)
-
- return og
+ return _attempt_calc_og(body.decode("utf-8", "ignore"))
-def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
+def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d6880f2e6e..d653a58be9 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,7 +16,7 @@
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
from twisted.web.http import Request
@@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource):
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-
- if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
- )
-
- file_info = FileInfo(
- server_name=None,
- file_id=media_id,
- url_cache=media_info["url_cache"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- )
-
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
-
- responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
- else:
- logger.info("Couldn't find any generated thumbnails")
- respond_404(request)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_id,
+ url_cache=media_info["url_cache"],
+ server_name=None,
+ )
async def _select_or_generate_local_thumbnail(
self,
@@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_info["filesystem_id"],
+ url_cache=None,
+ server_name=server_name,
+ )
+ async def _select_and_respond_with_thumbnail(
+ self,
+ request: Request,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[Dict[str, Any]],
+ file_id: str,
+ url_cache: Optional[str] = None,
+ server_name: Optional[str] = None,
+ ) -> None:
+ """
+ Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ request: The incoming request.
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+ """
if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
+ file_info = self._select_thumbnail(
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ thumbnail_infos,
+ file_id,
+ url_cache,
+ server_name,
)
- file_info = FileInfo(
- server_name=server_name,
- file_id=media_info["filesystem_id"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- )
-
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
+ if not file_info:
+ logger.info("Couldn't find a thumbnail matching the desired inputs")
+ respond_404(request)
+ return
responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(
+ request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+ )
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
@@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height: int,
desired_method: str,
desired_type: str,
- thumbnail_infos,
- ) -> dict:
+ thumbnail_infos: List[Dict[str, Any]],
+ file_id: str,
+ url_cache: Optional[str],
+ server_name: Optional[str],
+ ) -> Optional[FileInfo]:
+ """
+ Choose an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+
+ Returns:
+ The thumbnail which best matches the desired parameters.
+ """
+ desired_method = desired_method.lower()
+
+ # The chosen thumbnail.
+ thumbnail_info = None
+
d_w = desired_width
d_h = desired_height
- if desired_method.lower() == "crop":
+ if desired_method == "crop":
+ # Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list = []
+ # Other thumbnails.
crop_info_list2 = []
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "crop":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
- if t_method == "crop":
- aspect_quality = abs(d_w * t_h - d_h * t_w)
- min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
- size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info["thumbnail_type"]
- length_quality = info["thumbnail_length"]
- if t_w >= d_w or t_h >= d_h:
- crop_info_list.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ aspect_quality = abs(d_w * t_h - d_h * t_w)
+ min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+ size_quality = abs((d_w - t_w) * (d_h - t_h))
+ type_quality = desired_type != info["thumbnail_type"]
+ length_quality = info["thumbnail_length"]
+ if t_w >= d_w or t_h >= d_h:
+ crop_info_list.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
- else:
- crop_info_list2.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ )
+ else:
+ crop_info_list2.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
+ )
if crop_info_list:
- return min(crop_info_list)[-1]
- else:
- return min(crop_info_list2)[-1]
- else:
+ thumbnail_info = min(crop_info_list)[-1]
+ elif crop_info_list2:
+ thumbnail_info = min(crop_info_list2)[-1]
+ elif desired_method == "scale":
+ # Thumbnails that match equal or larger sizes of desired width/height.
info_list = []
+ # Other thumbnails.
info_list2 = []
+
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "scale":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
- if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
+ if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
- elif t_method == "scale":
+ else:
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
if info_list:
- return min(info_list)[-1]
- else:
- return min(info_list2)[-1]
+ thumbnail_info = min(info_list)[-1]
+ elif info_list2:
+ thumbnail_info = min(info_list2)[-1]
+
+ if thumbnail_info:
+ return FileInfo(
+ file_id=file_id,
+ url_cache=url_cache,
+ server_name=server_name,
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ thumbnail_length=thumbnail_info["thumbnail_length"],
+ )
+
+ # No matching thumbnail was found.
+ return None
diff --git a/synapse/server.py b/synapse/server.py
index 9cdda83aa1..9bdd3177d7 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -103,6 +103,7 @@ from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.tcp.external_cache import ExternalCache
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
+ from txredisapi import RedisProtocol
+
from synapse.handlers.oidc_handler import OidcHandler
from synapse.handlers.saml_handler import SamlHandler
@@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self)
+ @cache_in_self
+ def get_external_cache(self) -> ExternalCache:
+ return ExternalCache(self)
+
+ @cache_in_self
+ def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
+ if not self.config.redis.redis_enabled:
+ return None
+
+ # We only want to import redis module if we're using it, as we have
+ # `txredisapi` as an optional dependency.
+ from synapse.replication.tcp.redis import lazyConnection
+
+ logger.info(
+ "Connecting to redis (host=%r port=%r) for external cache",
+ self.config.redis_host,
+ self.config.redis_port,
+ )
+
+ return lazyConnection(
+ hs=self,
+ host=self.config.redis_host,
+ port=self.config.redis_port,
+ password=self.config.redis.redis_password,
+ reconnect=True,
+ )
+
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 84f59c7d85..3bd9ff8ca0 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -310,6 +310,7 @@ class StateHandler:
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
+ entry = None
else:
# otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
current_state_ids=state_ids_before_event,
)
- # XXX: can we update the state cache entry for the new state group? or
- # could we set a flag on resolve_state_groups_for_events to tell it to
- # always make a state group?
+ # Assign the new state group to the cached state entry.
+ #
+ # Note that this can race in that we could generate multiple state
+ # groups for the same state entry, but that is just inefficient
+ # rather than dangerous.
+ if entry and entry.state_group is None:
+ entry.state_group = state_group_before_event
#
# now if it's not a state event, we're done
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a19d65ad23..d2ba4bd2fc 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -262,13 +262,18 @@ class LoggingTransaction:
return self.txn.description
def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
+ """Similar to `executemany`, except `txn.rowcount` will not be correct
+ afterwards.
+
+ More efficient than `executemany` on PostgreSQL
+ """
+
if isinstance(self.database_engine, PostgresEngine):
from psycopg2.extras import execute_batch # type: ignore
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
- for val in args:
- self.execute(sql, val)
+ self.executemany(sql, args)
def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
@@ -888,7 +893,7 @@ class DatabasePool:
", ".join("?" for _ in keys[0]),
)
- txn.executemany(sql, vals)
+ txn.execute_batch(sql, vals)
async def simple_upsert(
self,
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index ae561a2da3..5d0845588c 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -43,6 +43,7 @@ from .end_to_end_keys import EndToEndKeyStore
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events_bg_updates import EventsBackgroundUpdatesStore
+from .events_forward_extremities import EventForwardExtremitiesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
@@ -118,6 +119,7 @@ class DataStore(
UIAuthStore,
CacheInvalidationWorkerStore,
ServerMetricsStore,
+ EventForwardExtremitiesStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 9097677648..659d8f245f 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
DELETE FROM device_lists_outbound_last_success
WHERE destination = ? AND user_id = ?
"""
- txn.executemany(sql, ((row[0], row[1]) for row in rows))
+ txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
logger.info("Pruned %d device list outbound pokes", count)
@@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Delete older entries in the table, as we really only care about
# when the latest change happened.
- txn.executemany(
+ txn.execute_batch(
"""
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 1b657191a9..438383abe1 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
VALUES (?, ?, ?, ?, ?, ?)
"""
- txn.executemany(
+ txn.execute_batch(
sql,
(
_gen_entry(user_id, actions)
@@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
],
)
- txn.executemany(
+ txn.execute_batch(
"""
UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ?
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3216b3f3c8..ccda9f1caa 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -473,8 +473,9 @@ class PersistEventsStore:
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
)
- @staticmethod
+ @classmethod
def _add_chain_cover_index(
+ cls,
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
@@ -614,60 +615,17 @@ class PersistEventsStore:
if not events_to_calc_chain_id_for:
return
- # We now calculate the chain IDs/sequence numbers for the events. We
- # do this by looking at the chain ID and sequence number of any auth
- # event with the same type/state_key and incrementing the sequence
- # number by one. If there was no match or the chain ID/sequence
- # number is already taken we generate a new chain.
- #
- # We need to do this in a topologically sorted order as we want to
- # generate chain IDs/sequence numbers of an event's auth events
- # before the event itself.
- chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
- new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
- for event_id in sorted_topologically(
- events_to_calc_chain_id_for, event_to_auth_chain
- ):
- existing_chain_id = None
- for auth_id in event_to_auth_chain.get(event_id, []):
- if event_to_types.get(event_id) == event_to_types.get(auth_id):
- existing_chain_id = chain_map[auth_id]
- break
-
- new_chain_tuple = None
- if existing_chain_id:
- # We found a chain ID/sequence number candidate, check its
- # not already taken.
- proposed_new_id = existing_chain_id[0]
- proposed_new_seq = existing_chain_id[1] + 1
- if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
- already_allocated = db_pool.simple_select_one_onecol_txn(
- txn,
- table="event_auth_chains",
- keyvalues={
- "chain_id": proposed_new_id,
- "sequence_number": proposed_new_seq,
- },
- retcol="event_id",
- allow_none=True,
- )
- if already_allocated:
- # Mark it as already allocated so we don't need to hit
- # the DB again.
- chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
- else:
- new_chain_tuple = (
- proposed_new_id,
- proposed_new_seq,
- )
-
- if not new_chain_tuple:
- new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
-
- chains_tuples_allocated.add(new_chain_tuple)
-
- chain_map[event_id] = new_chain_tuple
- new_chain_tuples[event_id] = new_chain_tuple
+ # Allocate chain ID/sequence numbers to each new event.
+ new_chain_tuples = cls._allocate_chain_ids(
+ txn,
+ db_pool,
+ event_to_room_id,
+ event_to_types,
+ event_to_auth_chain,
+ events_to_calc_chain_id_for,
+ chain_map,
+ )
+ chain_map.update(new_chain_tuples)
db_pool.simple_insert_many_txn(
txn,
@@ -794,6 +752,137 @@ class PersistEventsStore:
],
)
+ @staticmethod
+ def _allocate_chain_ids(
+ txn,
+ db_pool: DatabasePool,
+ event_to_room_id: Dict[str, str],
+ event_to_types: Dict[str, Tuple[str, str]],
+ event_to_auth_chain: Dict[str, List[str]],
+ events_to_calc_chain_id_for: Set[str],
+ chain_map: Dict[str, Tuple[int, int]],
+ ) -> Dict[str, Tuple[int, int]]:
+ """Allocates, but does not persist, chain ID/sequence numbers for the
+ events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
+ for info on args)
+ """
+
+ # We now calculate the chain IDs/sequence numbers for the events. We do
+ # this by looking at the chain ID and sequence number of any auth event
+ # with the same type/state_key and incrementing the sequence number by
+ # one. If there was no match or the chain ID/sequence number is already
+ # taken we generate a new chain.
+ #
+ # We try to reduce the number of times that we hit the database by
+ # batching up calls, to make this more efficient when persisting large
+ # numbers of state events (e.g. during joins).
+ #
+ # We do this by:
+ # 1. Calculating for each event which auth event will be used to
+ # inherit the chain ID, i.e. converting the auth chain graph to a
+ # tree that we can allocate chains on. We also keep track of which
+ # existing chain IDs have been referenced.
+ # 2. Fetching the max allocated sequence number for each referenced
+ # existing chain ID, generating a map from chain ID to the max
+ # allocated sequence number.
+ # 3. Iterating over the tree and allocating a chain ID/seq no. to the
+ # new event, by incrementing the sequence number from the
+ # referenced event's chain ID/seq no. and checking that the
+ # incremented sequence number hasn't already been allocated (by
+ # looking in the map generated in the previous step). We generate a
+ # new chain if the sequence number has already been allocated.
+ #
+
+ existing_chains = set() # type: Set[int]
+ tree = [] # type: List[Tuple[str, Optional[str]]]
+
+ # We need to do this in a topologically sorted order as we want to
+ # generate chain IDs/sequence numbers of an event's auth events before
+ # the event itself.
+ for event_id in sorted_topologically(
+ events_to_calc_chain_id_for, event_to_auth_chain
+ ):
+ for auth_id in event_to_auth_chain.get(event_id, []):
+ if event_to_types.get(event_id) == event_to_types.get(auth_id):
+ existing_chain_id = chain_map.get(auth_id)
+ if existing_chain_id:
+ existing_chains.add(existing_chain_id[0])
+
+ tree.append((event_id, auth_id))
+ break
+ else:
+ tree.append((event_id, None))
+
+ # Fetch the current max sequence number for each existing referenced chain.
+ sql = """
+ SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
+ WHERE %s
+ GROUP BY chain_id
+ """
+ clause, args = make_in_list_sql_clause(
+ db_pool.engine, "chain_id", existing_chains
+ )
+ txn.execute(sql % (clause,), args)
+
+ chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
+
+ # Allocate the new events chain ID/sequence numbers.
+ #
+ # To reduce the number of calls to the database we don't allocate a
+ # chain ID number in the loop, instead we use a temporary `object()` for
+ # each new chain ID. Once we've done the loop we generate the necessary
+ # number of new chain IDs in one call, replacing all temporary
+ # objects with real allocated chain IDs.
+
+ unallocated_chain_ids = set() # type: Set[object]
+ new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
+ for event_id, auth_event_id in tree:
+ # If we reference an auth_event_id we fetch the allocated chain ID,
+ # either from the existing `chain_map` or the newly generated
+ # `new_chain_tuples` map.
+ existing_chain_id = None
+ if auth_event_id:
+ existing_chain_id = new_chain_tuples.get(auth_event_id)
+ if not existing_chain_id:
+ existing_chain_id = chain_map[auth_event_id]
+
+ new_chain_tuple = None # type: Optional[Tuple[Any, int]]
+ if existing_chain_id:
+ # We found a chain ID/sequence number candidate, check its
+ # not already taken.
+ proposed_new_id = existing_chain_id[0]
+ proposed_new_seq = existing_chain_id[1] + 1
+
+ if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
+ new_chain_tuple = (
+ proposed_new_id,
+ proposed_new_seq,
+ )
+
+ # If we need to start a new chain we allocate a temporary chain ID.
+ if not new_chain_tuple:
+ new_chain_tuple = (object(), 1)
+ unallocated_chain_ids.add(new_chain_tuple[0])
+
+ new_chain_tuples[event_id] = new_chain_tuple
+ chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
+
+ # Generate new chain IDs for all unallocated chain IDs.
+ newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+ txn, len(unallocated_chain_ids)
+ )
+
+ # Map from potentially temporary chain ID to real chain ID
+ chain_id_to_allocated_map = dict(
+ zip(unallocated_chain_ids, newly_allocated_chain_ids)
+ ) # type: Dict[Any, int]
+ chain_id_to_allocated_map.update((c, c) for c in existing_chains)
+
+ return {
+ event_id: (chain_id_to_allocated_map[chain_id], seq)
+ for event_id, (chain_id, seq) in new_chain_tuples.items()
+ }
+
def _persist_transaction_ids_txn(
self,
txn: LoggingTransaction,
@@ -876,7 +965,7 @@ class PersistEventsStore:
WHERE room_id = ? AND type = ? AND state_key = ?
)
"""
- txn.executemany(
+ txn.execute_batch(
sql,
(
(
@@ -895,7 +984,7 @@ class PersistEventsStore:
)
# Now we actually update the current_state_events table
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?",
(
@@ -907,7 +996,7 @@ class PersistEventsStore:
# We include the membership in the current state table, hence we do
# a lookup when we insert. This assumes that all events have already
# been inserted into room_memberships.
- txn.executemany(
+ txn.execute_batch(
"""INSERT INTO current_state_events
(room_id, type, state_key, event_id, membership)
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -927,7 +1016,7 @@ class PersistEventsStore:
# we have no record of the fact the user *was* a member of the
# room but got, say, state reset out of it.
if to_delete or to_insert:
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM local_current_membership"
" WHERE room_id = ? AND user_id = ?",
(
@@ -938,7 +1027,7 @@ class PersistEventsStore:
)
if to_insert:
- txn.executemany(
+ txn.execute_batch(
"""INSERT INTO local_current_membership
(room_id, user_id, event_id, membership)
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -1738,7 +1827,7 @@ class PersistEventsStore:
"""
if events_and_contexts:
- txn.executemany(
+ txn.execute_batch(
sql,
(
(
@@ -1767,7 +1856,7 @@ class PersistEventsStore:
# Now we delete the staging area for *all* events that were being
# persisted.
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
((event.event_id,) for event, _ in all_events_and_contexts),
)
@@ -1886,7 +1975,7 @@ class PersistEventsStore:
" )"
)
- txn.executemany(
+ txn.execute_batch(
query,
[
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1900,7 +1989,7 @@ class PersistEventsStore:
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
- txn.executemany(
+ txn.execute_batch(
query,
[
(ev.event_id, ev.room_id)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e46e44ba54..5ca4fa6817 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -139,8 +139,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
-
def reindex_txn(txn):
sql = (
"SELECT stream_ordering, event_id, json FROM events"
@@ -178,9 +176,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
- for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
- clump = update_rows[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ txn.execute_batch(sql, update_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -210,8 +206,6 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
-
def reindex_search_txn(txn):
sql = (
"SELECT stream_ordering, event_id FROM events"
@@ -256,9 +250,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
- for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
- clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ txn.execute_batch(sql, rows_to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/events_forward_extremities.py b/synapse/storage/databases/main/events_forward_extremities.py
new file mode 100644
index 0000000000..0ac1da9c35
--- /dev/null
+++ b/synapse/storage/databases/main/events_forward_extremities.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Dict, List
+
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventForwardExtremitiesStore(SQLBaseStore):
+ async def delete_forward_extremities_for_room(self, room_id: str) -> int:
+ """Delete any extra forward extremities for a room.
+
+ Invalidates the "get_latest_event_ids_in_room" cache if any forward
+ extremities were deleted.
+
+ Returns count deleted.
+ """
+
+ def delete_forward_extremities_for_room_txn(txn):
+ # First we need to get the event_id to not delete
+ sql = """
+ SELECT event_id FROM event_forward_extremities
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id = ?
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ """
+ txn.execute(sql, (room_id,))
+ rows = txn.fetchall()
+ try:
+ event_id = rows[0][0]
+ logger.debug(
+ "Found event_id %s as the forward extremity to keep for room %s",
+ event_id,
+ room_id,
+ )
+ except KeyError:
+ msg = "No forward extremity event found for room %s" % room_id
+ logger.warning(msg)
+ raise SynapseError(400, msg)
+
+ # Now delete the extra forward extremities
+ sql = """
+ DELETE FROM event_forward_extremities
+ WHERE event_id != ? AND room_id = ?
+ """
+
+ txn.execute(sql, (event_id, room_id))
+ logger.info(
+ "Deleted %s extra forward extremities for room %s",
+ txn.rowcount,
+ room_id,
+ )
+
+ if txn.rowcount > 0:
+ # Invalidate the cache
+ self._invalidate_cache_and_stream(
+ txn, self.get_latest_event_ids_in_room, (room_id,),
+ )
+
+ return txn.rowcount
+
+ return await self.db_pool.runInteraction(
+ "delete_forward_extremities_for_room",
+ delete_forward_extremities_for_room_txn,
+ )
+
+ async def get_forward_extremities_for_room(self, room_id: str) -> List[Dict]:
+ """Get list of forward extremities for a room."""
+
+ def get_forward_extremities_for_room_txn(txn):
+ sql = """
+ SELECT event_id, state_group, depth, received_ts
+ FROM event_forward_extremities
+ INNER JOIN event_to_state_groups USING (event_id)
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id = ?
+ """
+
+ txn.execute(sql, (room_id,))
+ return self.db_pool.cursor_to_dict(txn)
+
+ return await self.db_pool.runInteraction(
+ "get_forward_extremities_for_room", get_forward_extremities_for_room_txn,
+ )
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 283c8a5e22..e017177655 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_origin = ? AND media_id = ?"
)
- txn.executemany(
+ txn.execute_batch(
sql,
(
(time_ms, media_origin, media_id)
@@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE media_id = ?"
)
- txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
+ txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
return await self.db_pool.runInteraction(
"update_cached_last_access_time", update_cache_txn
@@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache", _delete_url_cache_txn
@@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def _delete_url_cache_media_txn(txn):
sql = "DELETE FROM local_media_repository WHERE media_id = ?"
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
- txn.executemany(sql, [(media_id,) for media_id in media_ids])
+ txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
return await self.db_pool.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 5d668aadb2..ecfc9f20b1 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
)
# Update backward extremeties
- txn.executemany(
+ txn.execute_batch(
"INSERT INTO event_backward_extremities (room_id, event_id)"
" VALUES (?, ?)",
[(room_id, event_id) for event_id, in new_backwards_extrems],
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index bc7621b8d6..2687ef3e43 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -344,7 +344,9 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
- self.db_pool.simple_delete_one_txn(
+ # It is expected that there is exactly one pusher to delete, but
+ # if it isn't there (or there are multiple) delete them all.
+ self.db_pool.simple_delete_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 8d05288ed4..0618b4387a 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -360,6 +360,35 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
+ async def set_shadow_banned(self, user: UserID, shadow_banned: bool) -> None:
+ """Sets whether a user shadow-banned.
+
+ Args:
+ user: user ID of the user to test
+ shadow_banned: true iff the user is to be shadow-banned, false otherwise.
+ """
+
+ def set_shadow_banned_txn(txn):
+ self.db_pool.simple_update_one_txn(
+ txn,
+ table="users",
+ keyvalues={"name": user.to_string()},
+ updatevalues={"shadow_banned": shadow_banned},
+ )
+ # In order for this to apply immediately, clear the cache for this user.
+ tokens = self.db_pool.simple_select_onecol_txn(
+ txn,
+ table="access_tokens",
+ keyvalues={"user_id": user.to_string()},
+ retcol="token",
+ )
+ for token in tokens:
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_access_token, (token,)
+ )
+
+ await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
+
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = """
SELECT users.name as user_id,
@@ -1104,7 +1133,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
FROM user_threepids
"""
- txn.executemany(sql, [(id_server,) for id_server in id_servers])
+ txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
if id_servers:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index dcdaf09682..92382bed28 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -873,8 +873,6 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
"max_stream_id_exclusive", self._stream_order_on_start + 1
)
- INSERT_CLUMP_SIZE = 1000
-
def add_membership_profile_txn(txn):
sql = """
SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -915,9 +913,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
UPDATE room_memberships SET display_name = ?, avatar_url = ?
WHERE event_id = ? AND room_id = ?
"""
- for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
- clump = to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(to_update_sql, clump)
+ txn.execute_batch(to_update_sql, to_update)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
diff --git a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
index f35c70b699..9e8f35c1d2 100644
--- a/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
+++ b/synapse/storage/databases/main/schema/delta/59/01ignored_user.py
@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
# { "ignored_users": "@someone:example.org": {} }
ignored_users = content.get("ignored_users", {})
if isinstance(ignored_users, dict) and ignored_users:
- cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
+ cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
# Add indexes after inserting data for efficiency.
logger.info("Adding constraints to ignored_users table")
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 141207fb16..edde12fbd2 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -63,7 +63,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
- txn.executemany(sql, args)
+ txn.execute_batch(sql, args)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
@@ -75,7 +75,7 @@ class SearchWorkerStore(SQLBaseStore):
for entry in entries
)
- txn.executemany(sql, args)
+ txn.execute_batch(sql, args)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 0e31cc811a..89cdc84a9c 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
)
logger.info("[purge] removing redundant state groups")
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM state_groups_state WHERE state_group = ?",
((sg,) for sg in state_groups_to_delete),
)
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM state_groups WHERE id = ?",
((sg,) for sg in state_groups_to_delete),
)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index bb84c0d792..71ef5a72dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
import heapq
import logging
import threading
-from collections import deque
+from collections import OrderedDict
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union
import attr
-from typing_extensions import Deque
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque() # type: Deque[int]
+
+ # We use this as an ordered set, as we want to efficiently append items,
+ # remove items and get the first item. Since we insert IDs in order, the
+ # insertion ordering will ensure its in the correct ordering.
+ #
+ # The key and values are the same, but we never look at the values.
+ self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
def get_next(self):
"""
@@ -113,7 +118,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
yield next_id
finally:
with self._lock:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -140,7 +145,7 @@ class StreamIdGenerator:
self._current += n * self._step
for next_id in next_ids:
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -149,7 +154,7 @@ class StreamIdGenerator:
finally:
with self._lock:
for next_id in next_ids:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -162,7 +167,7 @@ class StreamIdGenerator:
"""
with self._lock:
if self._unfinished_ids:
- return self._unfinished_ids[0] - self._step
+ return next(iter(self._unfinished_ids)) - self._step
return self._current
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index c780ade077..0ec4dc2918 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -70,6 +70,11 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
...
@abc.abstractmethod
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ """Get the next `n` IDs in the sequence"""
+ ...
+
+ @abc.abstractmethod
def check_consistency(
self,
db_conn: "LoggingDatabaseConnection",
@@ -219,6 +224,17 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ first_id = self._current_max_id + 1
+ self._current_max_id += n
+ return [first_id + i for i in range(n)]
+
def check_consistency(
self,
db_conn: Connection,
|