diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index af8d59cf87..691f8f9adf 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -98,11 +98,14 @@ class EventTypes:
Retention = "m.room.retention"
- Presence = "m.presence"
-
Dummy = "org.matrix.dummy_event"
+class EduTypes:
+ Presence = "m.presence"
+ RoomKeyRequest = "m.room_key_request"
+
+
class RejectedReason:
AUTH_ERROR = "auth_error"
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 5d9d5a228f..c3f07bc1a3 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -14,7 +14,7 @@
# limitations under the License.
from collections import OrderedDict
-from typing import Any, Optional, Tuple
+from typing import Hashable, Optional, Tuple
from synapse.api.errors import LimitExceededError
from synapse.types import Requester
@@ -42,7 +42,9 @@ class Ratelimiter:
# * How many times an action has occurred since a point in time
# * The point in time
# * The rate_hz of this particular entry. This can vary per request
- self.actions = OrderedDict() # type: OrderedDict[Any, Tuple[float, int, float]]
+ self.actions = (
+ OrderedDict()
+ ) # type: OrderedDict[Hashable, Tuple[float, int, float]]
def can_requester_do_action(
self,
@@ -82,7 +84,7 @@ class Ratelimiter:
def can_do_action(
self,
- key: Any,
+ key: Hashable,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
@@ -175,7 +177,7 @@ class Ratelimiter:
def ratelimit(
self,
- key: Any,
+ key: Hashable,
rate_hz: Optional[float] = None,
burst_count: Optional[int] = None,
update: bool = True,
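
Tightening the key type from `Any` to `Hashable` matches how the limiter is used elsewhere in this patch: keys are plain strings (federation origins) or tuples (user ID plus device ID). A hedged sketch of that usage, with illustrative values, assuming `clock` is the homeserver's Clock and that `can_do_action` is truthy when the action is allowed (as the federation hunk later in this patch relies on):

    # Illustrative only: both string and tuple keys are Hashable.
    limiter = Ratelimiter(clock=clock, rate_hz=20, burst_count=100)

    # Key requests over federation are limited per origin server...
    if not limiter.can_do_action("remote.example.com"):
        pass  # over the limit: drop the EDU

    # ...while local requests are limited per (user, device) pair.
    if not limiter.can_do_action(("@alice:example.com", "ADEVICE")):
        pass  # over the limit: skip this message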
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
index b4bd4d8e7a..9f99651aa2 100644
--- a/synapse/app/admin_cmd.py
+++ b/synapse/app/admin_cmd.py
@@ -210,7 +210,9 @@ def start(config_options):
config.update_user_directory = False
config.run_background_tasks = False
config.start_pushers = False
+ config.pusher_shard_config.instances = []
config.send_federation = False
+ config.federation_shard_config.instances = []
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 6526acb2f2..dc0d3eb725 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -645,9 +645,6 @@ class GenericWorkerServer(HomeServer):
self.get_tcp_replication().start_replication(self)
- async def remove_pusher(self, app_id, push_key, user_id):
- self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
-
@cache_in_self
def get_replication_data_handler(self):
return GenericWorkerReplicationHandler(self)
@@ -922,22 +919,6 @@ def start(config_options):
# For other worker types we force this to off.
config.appservice.notify_appservices = False
- if config.worker_app == "synapse.app.pusher":
- if config.server.start_pushers:
- sys.stderr.write(
- "\nThe pushers must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``start_pushers: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.server.start_pushers = True
- else:
- # For other worker types we force this to off.
- config.server.start_pushers = False
-
if config.worker_app == "synapse.app.user_dir":
if config.server.update_user_directory:
sys.stderr.write(
@@ -954,22 +935,6 @@ def start(config_options):
# For other worker types we force this to off.
config.server.update_user_directory = False
- if config.worker_app == "synapse.app.federation_sender":
- if config.worker.send_federation:
- sys.stderr.write(
- "\nThe send_federation must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``send_federation: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.worker.send_federation = True
- else:
- # For other worker types we force this to off.
- config.worker.send_federation = False
-
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
hs = GenericWorkerServer(
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 97399eb9ba..4026966711 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -21,7 +21,7 @@ import os
from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent
-from typing import Any, Iterable, List, MutableMapping, Optional
+from typing import Any, Iterable, List, MutableMapping, Optional, Union
import attr
import jinja2
@@ -147,7 +147,20 @@ class Config:
return int(value) * size
@staticmethod
- def parse_duration(value):
+ def parse_duration(value: Union[str, int]) -> int:
+ """Convert a duration as a string or integer to a number of milliseconds.
+
+ If an integer is provided it is treated as milliseconds and is unchanged.
+
+ String durations can have a suffix of 's', 'm', 'h', 'd', 'w', or 'y'.
+ No suffix is treated as milliseconds.
+
+ Args:
+ value: The duration to parse.
+
+ Returns:
+ The number of milliseconds in the duration.
+ """
if isinstance(value, int):
return value
second = 1000
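
For reference, a few conversions implied by the docstring's contract (a hedged sketch; `parse_duration` is a staticmethod, so it can be called on the class):

    Config.parse_duration(500)    # -> 500      (ints pass through unchanged)
    Config.parse_duration("500")  # -> 500      (no suffix means milliseconds)
    Config.parse_duration("10s")  # -> 10000
    Config.parse_duration("1h")   # -> 3600000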
@@ -831,22 +844,23 @@ class ShardedWorkerHandlingConfig:
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key."""
- # If multiple instances are not defined we always return true
- if not self.instances or len(self.instances) == 1:
- return True
+ # If no instances are defined we assume some other worker is handling
+ # this.
+ if not self.instances:
+ return False
- return self.get_instance(key) == instance_name
+ return self._get_instance(key) == instance_name
- def get_instance(self, key: str) -> str:
+ def _get_instance(self, key: str) -> str:
"""Get the instance responsible for handling the given key.
- Note: For things like federation sending the config for which instance
- is sending is known only to the sender instance if there is only one.
- Therefore `should_handle` should be used where possible.
+ Note: For federation sending and pushers the config for which instance
+ is sending is known only to the sender instance, so we don't expose this
+ method by default.
"""
if not self.instances:
- return "master"
+ raise Exception("Unknown worker")
if len(self.instances) == 1:
return self.instances[0]
@@ -863,4 +877,21 @@ class ShardedWorkerHandlingConfig:
return self.instances[remainder]
+@attr.s
+class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
+ """A version of `ShardedWorkerHandlingConfig` that is used for config
+ options where all instances know which instances are responsible for the
+ sharded work.
+ """
+
+ def __attrs_post_init__(self):
+ # We require that `self.instances` is non-empty.
+ if not self.instances:
+ raise Exception("Got empty list of instances for shard config")
+
+ def get_instance(self, key: str) -> str:
+ """Get the instance responsible for handling the given key."""
+ return self._get_instance(key)
+
+
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 70025b5d60..db16c86f50 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -149,4 +149,6 @@ class ShardedWorkerHandlingConfig:
instances: List[str]
def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ...
+
+class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
def get_instance(self, key: str) -> str: ...
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index 9f3c57e6a1..55e4db5442 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -41,6 +41,10 @@ class FederationConfig(Config):
)
self.federation_metrics_domains = set(federation_metrics_domains)
+ self.allow_profile_lookup_over_federation = config.get(
+ "allow_profile_lookup_over_federation", True
+ )
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
## Federation ##
@@ -66,6 +70,12 @@ class FederationConfig(Config):
#federation_metrics_domains:
# - matrix.org
# - example.com
+
+ # Uncomment to disable profile lookup over federation. By default, the
+ # Federation API allows other homeservers to obtain profile data of any user
+ # on this homeserver. Defaults to 'true'.
+ #
+ #allow_profile_lookup_over_federation: false
"""
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 3adbfb73e6..7831a2ef79 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import Config, ShardedWorkerHandlingConfig
+from ._base import Config
class PushConfig(Config):
@@ -27,9 +27,6 @@ class PushConfig(Config):
"group_unread_count_by_room", True
)
- pusher_instances = config.get("pusher_instances") or []
- self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
-
# There was a 'redact_content' setting but it was mistakenly read from the
# 'email' section. Check for the flag in the 'push' section, and log,
# but do not honour it to avoid nasty surprises when people upgrade.
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index def33a60ad..847d25122c 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -102,6 +102,16 @@ class RatelimitConfig(Config):
defaults={"per_second": 0.01, "burst_count": 3},
)
+ # Ratelimit cross-user key requests:
+ # * For local requests this is keyed by the sending device.
+ # * For requests received over federation this is keyed by the origin.
+ #
+ # Note that this isn't exposed in the configuration as it is obscure.
+ self.rc_key_requests = RateLimitConfig(
+ config.get("rc_key_requests", {}),
+ defaults={"per_second": 20, "burst_count": 100},
+ )
+
self.rc_3pid_validation = RateLimitConfig(
config.get("rc_3pid_validation") or {},
defaults={"per_second": 0.003, "burst_count": 5},
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 52849c3256..69d9de5a43 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -206,7 +206,6 @@ class ContentRepositoryConfig(Config):
def generate_config_section(self, data_dir_path, **kwargs):
media_store = os.path.join(data_dir_path, "media_store")
- uploads_path = os.path.join(data_dir_path, "uploads")
formatted_thumbnail_sizes = "".join(
THUMBNAIL_SIZE_YAML % s for s in DEFAULT_THUMBNAIL_SIZES
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 6f3325ff81..2afca36e7d 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -263,6 +263,12 @@ class ServerConfig(Config):
False,
)
+ # Whether to retrieve and display profile data for a user when they
+ # are invited to a room
+ self.include_profile_data_on_invite = config.get(
+ "include_profile_data_on_invite", True
+ )
+
if "restrict_public_rooms_to_local_users" in config and (
"allow_public_rooms_without_auth" in config
or "allow_public_rooms_over_federation" in config
@@ -391,7 +397,6 @@ class ServerConfig(Config):
if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/":
self.public_baseurl += "/"
- self.start_pushers = config.get("start_pushers", True)
# (undocumented) option for torturing the worker-mode replication a bit,
# for testing. The value defines the number of milliseconds to pause before
@@ -848,6 +853,14 @@ class ServerConfig(Config):
#
#limit_profile_requests_to_users_who_share_rooms: true
+ # Uncomment to prevent a user's profile data from being retrieved and
+ # displayed in a room until they have joined it. By default, a user's
+ # profile data is included in an invite event, regardless of the values
+ # of the above two settings, and whether or not the users share a server.
+ # Defaults to 'true'.
+ #
+ #include_profile_data_on_invite: false
+
# If set to 'true', removes the need for authentication to access the server's
# public rooms directory through the client API, meaning that anyone can
# query the room directory. Defaults to 'false'.
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index c8d19c5d6b..8d05ef173c 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -24,32 +24,46 @@ class UserDirectoryConfig(Config):
section = "userdirectory"
def read_config(self, config, **kwargs):
- self.user_directory_search_enabled = True
- self.user_directory_search_all_users = False
- user_directory_config = config.get("user_directory", None)
- if user_directory_config:
- self.user_directory_search_enabled = user_directory_config.get(
- "enabled", True
- )
- self.user_directory_search_all_users = user_directory_config.get(
- "search_all_users", False
- )
+ user_directory_config = config.get("user_directory") or {}
+ self.user_directory_search_enabled = user_directory_config.get("enabled", True)
+ self.user_directory_search_all_users = user_directory_config.get(
+ "search_all_users", False
+ )
+ self.user_directory_search_prefer_local_users = user_directory_config.get(
+ "prefer_local_users", False
+ )
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# User Directory configuration
#
- # 'enabled' defines whether users can search the user directory. If
- # false then empty responses are returned to all queries. Defaults to
- # true.
- #
- # 'search_all_users' defines whether to search all users visible to your HS
- # when searching the user directory, rather than limiting to users visible
- # in public rooms. Defaults to false. If you set it True, you'll have to
- # rebuild the user_directory search indexes, see
- # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
- #
- #user_directory:
- # enabled: true
- # search_all_users: false
+ user_directory:
+ # Defines whether users can search the user directory. If false then
+ # empty responses are returned to all queries. Defaults to true.
+ #
+ # Uncomment to disable the user directory.
+ #
+ #enabled: false
+
+ # Defines whether to search all users visible to your HS when searching
+ # the user directory, rather than limiting to users visible in public
+ # rooms. Defaults to false.
+ #
+ # If you set it true, you'll have to rebuild the user_directory search
+ # indexes, see:
+ # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
+ #
+ # Uncomment to return search results containing all known users, even if that
+ # user does not share a room with the requester.
+ #
+ #search_all_users: true
+
+ # Defines whether to prefer local users in search query results.
+ # If true, local users are more likely to appear above remote users
+ # when searching the user directory. Defaults to false.
+ #
+ # Uncomment to prefer local over remote users in user directory search
+ # results.
+ #
+ #prefer_local_users: true
"""
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 7a0ca16da8..ac92375a85 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -17,9 +17,28 @@ from typing import List, Union
import attr
-from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
+from ._base import (
+ Config,
+ ConfigError,
+ RoutableShardedWorkerHandlingConfig,
+ ShardedWorkerHandlingConfig,
+)
from .server import ListenerConfig, parse_listener_def
+_FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR = """
+The send_federation config option must be disabled in the main
+synapse process before federation senders can be run in a separate worker.
+
+Please add ``send_federation: false`` to the main config
+"""
+
+_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR = """
+The start_pushers config option must be disabled in the main
+synapse process before pushers can be run in a separate worker.
+
+Please add ``start_pushers: false`` to the main config
+"""
+
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
"""Helper for allowing parsing a string or list of strings to a config
@@ -103,6 +122,7 @@ class WorkerConfig(Config):
self.worker_replication_secret = config.get("worker_replication_secret", None)
self.worker_name = config.get("worker_name", self.worker_app)
+ self.instance_name = self.worker_name or "master"
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
@@ -118,12 +138,41 @@ class WorkerConfig(Config):
)
)
- # Whether to send federation traffic out in this process. This only
- # applies to some federation traffic, and so shouldn't be used to
- # "disable" federation
- self.send_federation = config.get("send_federation", True)
+ # Handle federation sender configuration.
+ #
+ # There are two ways of configuring which instances handle federation
+ # sending:
+ # 1. The old way where "send_federation" is set to false and running a
+ # `synapse.app.federation_sender` worker app.
+ # 2. Specifying the workers sending federation in
+ # `federation_sender_instances`.
+ #
+
+ send_federation = config.get("send_federation", True)
+
+ federation_sender_instances = config.get("federation_sender_instances")
+ if federation_sender_instances is None:
+ # Default to an empty list, which means "another, unknown, worker is
+ # responsible for it".
+ federation_sender_instances = []
- federation_sender_instances = config.get("federation_sender_instances") or []
+ # If no federation sender instances are set we check if
+ # `send_federation` is set, which means the master instance handles it.
+ if send_federation:
+ federation_sender_instances = ["master"]
+
+ if self.worker_app == "synapse.app.federation_sender":
+ if send_federation:
+ # If we're running federation senders, and not using
+ # `federation_sender_instances`, then we should have
+ # explicitly set `send_federation` to false.
+ raise ConfigError(
+ _FEDERATION_SENDER_WITH_SEND_FEDERATION_ENABLED_ERROR
+ )
+
+ federation_sender_instances = [self.worker_name]
+
+ self.send_federation = self.instance_name in federation_sender_instances
self.federation_shard_config = ShardedWorkerHandlingConfig(
federation_sender_instances
)
@@ -164,7 +213,37 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `receipts` messages."
)
- self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
+ if len(self.writers.events) == 0:
+ raise ConfigError("Must specify at least one instance to handle `events`.")
+
+ self.events_shard_config = RoutableShardedWorkerHandlingConfig(
+ self.writers.events
+ )
+
+ # Handle sharded push
+ start_pushers = config.get("start_pushers", True)
+ pusher_instances = config.get("pusher_instances")
+ if pusher_instances is None:
+ # Default to an empty list, which means "another, unknown, worker is
+ # responsible for it".
+ pusher_instances = []
+
+ # If no pusher instances are set we check if `start_pushers` is
+ # set, which means the master instance handles them.
+ if start_pushers:
+ pusher_instances = ["master"]
+
+ if self.worker_app == "synapse.app.pusher":
+ if start_pushers:
+ # If we're running pushers, and not using
+ # `pusher_instances`, then we should have explicitly set
+ # `start_pushers` to false.
+ raise ConfigError(_PUSHER_WITH_START_PUSHERS_ENABLED_ERROR)
+
+ pusher_instances = [self.instance_name]
+
+ self.start_pushers = self.instance_name in pusher_instances
+ self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
# Whether this worker should run background tasks or not.
#
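
Concretely, the two configuration styles described in the comments above might look like this (a hedged YAML sketch; the worker names are illustrative):

    # Style 1 (legacy): the main process opts out and one dedicated worker
    # runs the federation sender app.
    send_federation: false                       # main config
    worker_app: synapse.app.federation_sender    # worker config

    # Style 2 (sharded): list the sending instances explicitly, visible to all.
    federation_sender_instances:
      - federation_sender1
      - federation_sender2

The same pattern applies to pushers via `start_pushers` and `pusher_instances`.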
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 8d4bb621e7..2f832b47f6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -34,7 +34,7 @@ from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
@@ -44,6 +44,7 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
+from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
@@ -869,6 +870,13 @@ class FederationHandlerRegistry:
# EDU received.
self._edu_type_to_instance = {} # type: Dict[str, List[str]]
+ # A rate limiter for incoming room key requests per origin.
+ self._room_key_request_rate_limiter = Ratelimiter(
+ clock=self.clock,
+ rate_hz=self.config.rc_key_requests.per_second,
+ burst_count=self.config.rc_key_requests.burst_count,
+ )
+
def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
):
@@ -917,7 +925,15 @@ class FederationHandlerRegistry:
self._edu_type_to_instance[edu_type] = instance_names
async def on_edu(self, edu_type: str, origin: str, content: dict):
- if not self.config.use_presence and edu_type == "m.presence":
+ if not self.config.use_presence and edu_type == EduTypes.Presence:
+ return
+
+ # If the incoming room key requests from a particular origin are over
+ # the limit, drop them.
+ if (
+ edu_type == EduTypes.RoomKeyRequest
+ and not self._room_key_request_rate_limiter.can_do_action(origin)
+ ):
return
# Check if we have a handler on this instance
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 97fc4d0a82..24ebc4b803 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -474,7 +474,7 @@ class FederationSender:
self._processing_pending_presence = False
def send_presence_to_destinations(
- self, states: List[UserPresenceState], destinations: List[str]
+ self, states: Iterable[UserPresenceState], destinations: Iterable[str]
) -> None:
"""Send the given presence states to the given destinations.
destinations (list[str])
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index cce83704d4..2cf935f38d 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -484,10 +484,9 @@ class FederationQueryServlet(BaseFederationServlet):
# This is when we receive a server-server Query
async def on_GET(self, origin, content, query, query_type):
- return await self.handler.on_query_request(
- query_type,
- {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()},
- )
+ args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}
+ args["origin"] = origin
+ return await self.handler.on_query_request(query_type, args)
class FederationMakeJoinServlet(BaseFederationServlet):
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 1aa7d803b5..7db4f48965 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -16,7 +16,9 @@
import logging
from typing import TYPE_CHECKING, Any, Dict
+from synapse.api.constants import EduTypes
from synapse.api.errors import SynapseError
+from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
get_active_span_text_map,
@@ -25,7 +27,7 @@ from synapse.logging.opentracing import (
start_active_span,
)
from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
-from synapse.types import JsonDict, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string
@@ -78,6 +80,12 @@ class DeviceMessageHandler:
ReplicationUserDevicesResyncRestServlet.make_client(hs)
)
+ self._ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=hs.config.rc_key_requests.per_second,
+ burst_count=hs.config.rc_key_requests.burst_count,
+ )
+
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
local_messages = {}
sender_user_id = content["sender"]
@@ -168,15 +176,27 @@ class DeviceMessageHandler:
async def send_device_message(
self,
- sender_user_id: str,
+ requester: Requester,
message_type: str,
messages: Dict[str, Dict[str, JsonDict]],
) -> None:
+ sender_user_id = requester.user.to_string()
+
set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id)
local_messages = {}
remote_messages = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
for user_id, by_device in messages.items():
+ # Ratelimit local cross-user key requests by the sending device.
+ if (
+ message_type == EduTypes.RoomKeyRequest
+ and user_id != sender_user_id
+ and not self._ratelimiter.can_do_action(
+ (sender_user_id, requester.device_id)
+ )
+ ):
+ continue
+
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
messages_by_device = {
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 3e23f82cf7..f46cab7325 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -17,7 +17,7 @@ import logging
import random
from typing import TYPE_CHECKING, Iterable, List, Optional
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
from synapse.handlers.presence import format_user_presence_state
@@ -113,7 +113,7 @@ class EventStreamHandler(BaseHandler):
states = await presence_handler.get_states(users)
to_add.extend(
{
- "type": EventTypes.Presence,
+ "type": EduTypes.Presence,
"content": format_user_presence_state(state, time_now),
}
for state in states
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 78c3e5a10b..71a5076672 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Optional, Tuple
from twisted.internet import defer
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EduTypes, EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
@@ -412,7 +412,7 @@ class InitialSyncHandler(BaseHandler):
return [
{
- "type": EventTypes.Presence,
+ "type": EduTypes.Presence,
"content": format_user_presence_state(s, time_now),
}
for s in states
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c03f6c997b..1b7c065b34 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -387,6 +387,12 @@ class EventCreationHandler:
self.room_invite_state_types = self.hs.config.room_invite_state_types
+ self.membership_types_to_include_profile_data_in = (
+ {Membership.JOIN, Membership.INVITE}
+ if self.hs.config.include_profile_data_on_invite
+ else {Membership.JOIN}
+ )
+
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
# This is only used to get at ratelimit function, and maybe_kick_guest_users
@@ -500,7 +506,7 @@ class EventCreationHandler:
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
- if membership in {Membership.JOIN, Membership.INVITE}:
+ if membership in self.membership_types_to_include_profile_data_in:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index fb85b19770..b6a9ce4f38 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -849,6 +849,9 @@ class PresenceHandler(BasePresenceHandler):
"""Process current state deltas to find new joins that need to be
handled.
"""
+ # A map of destination to a set of user state that they should receive
+ presence_destinations = {} # type: Dict[str, Set[UserPresenceState]]
+
for delta in deltas:
typ = delta["type"]
state_key = delta["state_key"]
@@ -858,6 +861,7 @@ class PresenceHandler(BasePresenceHandler):
logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+ # Drop any event that isn't a membership event
if typ != EventTypes.Member:
continue
@@ -880,13 +884,38 @@ class PresenceHandler(BasePresenceHandler):
# Ignore changes to join events.
continue
- await self._on_user_joined_room(room_id, state_key)
+ # Retrieve any user presence state updates that need to be sent as a result,
+ # and the destinations that need to receive them
+ destinations, user_presence_states = await self._on_user_joined_room(
+ room_id, state_key
+ )
+
+ # Insert the destinations and respective updates into our destinations dict
+ for destination in destinations:
+ presence_destinations.setdefault(destination, set()).update(
+ user_presence_states
+ )
+
+ # Send out user presence updates for each destination
+ for destination, user_state_set in presence_destinations.items():
+ self.federation.send_presence_to_destinations(
+ destinations=[destination], states=user_state_set
+ )
- async def _on_user_joined_room(self, room_id: str, user_id: str) -> None:
+ async def _on_user_joined_room(
+ self, room_id: str, user_id: str
+ ) -> Tuple[List[str], List[UserPresenceState]]:
"""Called when we detect a user joining the room via the current state
- delta stream.
- """
+ delta stream. Returns the destinations that need to be updated and the
+ presence updates to send to them.
+
+ Args:
+ room_id: The ID of the room that the user has joined.
+ user_id: The ID of the user that has joined the room.
+
+ Returns:
+ A tuple of destinations and presence updates to send to them.
+ """
if self.is_mine_id(user_id):
# If this is a local user then we need to send their presence
# out to hosts in the room (who don't already have it)
@@ -894,15 +923,15 @@ class PresenceHandler(BasePresenceHandler):
# TODO: We should be able to filter the hosts down to those that
# haven't previously seen the user
- state = await self.current_state_for_user(user_id)
- hosts = await self.state.get_current_hosts_in_room(room_id)
+ remote_hosts = await self.state.get_current_hosts_in_room(room_id)
# Filter out ourselves.
- hosts = {host for host in hosts if host != self.server_name}
+ filtered_remote_hosts = [
+ host for host in remote_hosts if host != self.server_name
+ ]
- self.federation.send_presence_to_destinations(
- states=[state], destinations=hosts
- )
+ state = await self.current_state_for_user(user_id)
+ return filtered_remote_hosts, [state]
else:
# A remote user has joined the room, so we need to:
# 1. Check if this is a new server in the room
@@ -915,6 +944,8 @@ class PresenceHandler(BasePresenceHandler):
# TODO: Check that this is actually a new server joining the
# room.
+ remote_host = get_domain_from_id(user_id)
+
users = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, users))
@@ -934,10 +965,7 @@ class PresenceHandler(BasePresenceHandler):
or state.status_msg is not None
]
- if states:
- self.federation.send_presence_to_destinations(
- states=states, destinations=[get_domain_from_id(user_id)]
- )
+ return [remote_host], states
def should_notify(old_state, new_state):
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 2f62d84fb5..dd59392bda 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -310,6 +310,15 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user)
async def on_profile_query(self, args: JsonDict) -> JsonDict:
+ """Handles federation profile query requests."""
+
+ if not self.hs.config.allow_profile_lookup_over_federation:
+ raise SynapseError(
+ 403,
+ "Profile lookup over federation is disabled on this homeserver",
+ Codes.FORBIDDEN,
+ )
+
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4e8ed7b33f..ce644e01ad 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -277,8 +277,9 @@ class SyncHandler:
user_id = sync_config.user.to_string()
await self.auth.check_auth_blocking(requester=requester)
- res = await self.response_cache.wrap(
+ res = await self.response_cache.wrap_conditional(
sync_config.request_key,
+ lambda result: since_token != result.next_batch,
self._wait_for_sync_for_user,
sync_config,
since_token,
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 4a4fb5ef26..30153237e3 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -16,6 +16,10 @@ import logging
import time
from typing import Optional, Union
+import attr
+from zope.interface import implementer
+
+from twisted.internet.interfaces import IAddress
from twisted.python.failure import Failure
from twisted.web.server import Request, Site
@@ -333,26 +337,77 @@ class SynapseRequest(Request):
class XForwardedForRequest(SynapseRequest):
- def __init__(self, *args, **kw):
- SynapseRequest.__init__(self, *args, **kw)
+ """Request object which honours proxy headers
+ Extends SynapseRequest to replace getClientIP, getClientAddress, and isSecure with
+ information from request headers.
"""
- Add a layer on top of another request that only uses the value of an
- X-Forwarded-For header as the result of C{getClientIP}.
- """
- def getClientIP(self):
+ # the client IP and ssl flag, as extracted from the headers.
+ _forwarded_for = None # type: Optional[_XForwardedForAddress]
+ _forwarded_https = False # type: bool
+
+ def requestReceived(self, command, path, version):
+ # this method is called by the Channel once the full request has been
+ # received, to dispatch the request to a resource.
+ # We can use it to set the IP address and protocol according to the
+ # headers.
+ self._process_forwarded_headers()
+ return super().requestReceived(command, path, version)
+
+ def _process_forwarded_headers(self):
+ headers = self.requestHeaders.getRawHeaders(b"x-forwarded-for")
+ if not headers:
+ return
+
+ # for now, we just use the first x-forwarded-for header. Really, we ought
+ # to start from the client IP address, and check whether it is trusted; if it
+ # is, work backwards through the headers until we find an untrusted address.
+ # see https://github.com/matrix-org/synapse/issues/9471
+ self._forwarded_for = _XForwardedForAddress(
+ headers[0].split(b",")[0].strip().decode("ascii")
+ )
+
+ # if we got an x-forwarded-for header, also look for an x-forwarded-proto header
+ header = self.getHeader(b"x-forwarded-proto")
+ if header is not None:
+ self._forwarded_https = header.lower() == b"https"
+ else:
+ # this is done largely for backwards-compatibility so that people that
+ # haven't set an x-forwarded-proto header don't get a redirect loop.
+ logger.warning(
+ "forwarded request lacks an x-forwarded-proto header: assuming https"
+ )
+ self._forwarded_https = True
+
+ def isSecure(self):
+ if self._forwarded_https:
+ return True
+ return super().isSecure()
+
+ def getClientIP(self) -> str:
"""
- @return: The client address (the first address) in the value of the
- I{X-Forwarded-For header}. If the header is not present, return
- C{b"-"}.
+ Return the IP address of the client who submitted this request.
+
+ This method is deprecated. Use getClientAddress() instead.
"""
- return (
- self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0]
- .split(b",")[0]
- .strip()
- .decode("ascii")
- )
+ if self._forwarded_for is not None:
+ return self._forwarded_for.host
+ return super().getClientIP()
+
+ def getClientAddress(self) -> IAddress:
+ """
+ Return the address of the client who submitted this request.
+ """
+ if self._forwarded_for is not None:
+ return self._forwarded_for
+ return super().getClientAddress()
+
+
+@implementer(IAddress)
+@attr.s(frozen=True, slots=True)
+class _XForwardedForAddress:
+ host = attr.ib(type=str)
class SynapseSite(Site):
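
A hedged illustration of the new header handling, with made-up values and assuming `request` is an XForwardedForRequest (`requestReceived` normally triggers `_process_forwarded_headers`; it is called directly here for clarity):

    request.requestHeaders.setRawHeaders(b"x-forwarded-for", [b"203.0.113.5, 10.0.0.1"])
    request.requestHeaders.setRawHeaders(b"x-forwarded-proto", [b"https"])
    request._process_forwarded_headers()

    request.getClientIP()            # "203.0.113.5" (first entry only, for now)
    request.getClientAddress().host  # "203.0.113.5"
    request.isSecure()               # True, from x-forwarded-proto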
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index b9d3da2e0a..f4d7e199e9 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -74,6 +74,7 @@ class HttpPusher(Pusher):
self.timed_call = None
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
+ self._pusherpool = hs.get_pusherpool()
self.data = pusher_config.data
if self.data is None:
@@ -299,7 +300,7 @@ class HttpPusher(Pusher):
)
else:
logger.info("Pushkey %s was rejected: removing", pk)
- await self.hs.remove_pusher(self.app_id, pk, self.user_id)
+ await self._pusherpool.remove_pusher(self.app_id, pk, self.user_id)
return True
async def _build_notification_dict(
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index ae1145be0e..21f14f05f0 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -25,6 +25,7 @@ from synapse.metrics.background_process_metrics import (
)
from synapse.push import Pusher, PusherConfig, PusherConfigException
from synapse.push.pusher import PusherFactory
+from synapse.replication.http.push import ReplicationRemovePusherRestServlet
from synapse.types import JsonDict, RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
@@ -58,7 +59,6 @@ class PusherPool:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.pusher_factory = PusherFactory(hs)
- self._should_start_pushers = hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
@@ -67,6 +67,16 @@ class PusherPool:
# We shard the handling of push notifications by user ID.
- self._pusher_shard_config = hs.config.push.pusher_shard_config
+ self._pusher_shard_config = hs.config.worker.pusher_shard_config
self._instance_name = hs.get_instance_name()
+ self._should_start_pushers = (
+ self._instance_name in self._pusher_shard_config.instances
+ )
+
+ # We can only delete pushers on master.
+ self._remove_pusher_client = None
+ if hs.config.worker.worker_app:
+ self._remove_pusher_client = ReplicationRemovePusherRestServlet.make_client(
+ hs
+ )
# Record the last stream ID that we were poked about so we can get
# changes since then. We set this to the current max stream ID on
@@ -175,9 +185,6 @@ class PusherPool:
user_id: user to remove pushers for
access_tokens: access token *ids* to remove pushers for
"""
- if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
- return
-
tokens = set(access_tokens)
for p in await self.store.get_pushers_by_user_id(user_id):
if p.access_token in tokens:
@@ -380,6 +387,12 @@ class PusherPool:
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
- await self.store.delete_pusher_by_app_id_pushkey_user_id(
- app_id, pushkey, user_id
- )
+ # We can only delete pushers on master.
+ if self._remove_pusher_client:
+ await self._remove_pusher_client(
+ app_id=app_id, pushkey=pushkey, user_id=user_id
+ )
+ else:
+ await self.store.delete_pusher_by_app_id_pushkey_user_id(
+ app_id, pushkey, user_id
+ )
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 8a2b73b75e..321a333820 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -106,6 +106,9 @@ CONDITIONAL_REQUIREMENTS = {
"pysaml2>=4.5.0;python_version>='3.6'",
],
"oidc": ["authlib>=0.14.0"],
+ # systemd-python is necessary for logging to the systemd journal via
+ # `systemd.journal.JournalHandler`, as is documented in
+ # `contrib/systemd/log_config.yaml`.
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
"sentry": ["sentry-sdk>=0.7.2"],
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index dd527e807f..cb4a52dbe9 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -21,6 +21,7 @@ from synapse.replication.http import (
login,
membership,
presence,
+ push,
register,
send_event,
streams,
@@ -42,6 +43,7 @@ class ReplicationRestResource(JsonResource):
membership.register_servlets(hs, self)
streams.register_servlets(hs, self)
account_data.register_servlets(hs, self)
+ push.register_servlets(hs, self)
# The following can't currently be instantiated on workers.
if hs.config.worker.worker_app is None:
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 7a0dbb5b1a..8af53b4f28 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -213,8 +213,9 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
content = parse_json_object_from_request(request)
args = content["args"]
+ args["origin"] = content["origin"]
- logger.info("Got %r query", query_type)
+ logger.info("Got %r query from %s", query_type, args["origin"])
result = await self.registry.on_query(query_type, args)
diff --git a/synapse/replication/http/push.py b/synapse/replication/http/push.py
new file mode 100644
index 0000000000..054ed64d34
--- /dev/null
+++ b/synapse/replication/http/push.py
@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.servlet import parse_json_object_from_request
+from synapse.replication.http._base import ReplicationEndpoint
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
+ """Deletes the given pusher.
+
+ Request format:
+
+ POST /_synapse/replication/remove_pusher/:user_id
+
+ {
+ "app_id": "<some_id>",
+ "pushkey": "<some_key>"
+ }
+
+ """
+
+ NAME = "add_user_account_data"
+ PATH_ARGS = ("user_id",)
+ CACHE = False
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+
+ self.pusher_pool = hs.get_pusherpool()
+
+ @staticmethod
+ async def _serialize_payload(app_id, pushkey, user_id):
+ payload = {
+ "app_id": app_id,
+ "pushkey": pushkey,
+ }
+
+ return payload
+
+ async def _handle_request(self, request, user_id):
+ content = parse_json_object_from_request(request)
+
+ app_id = content["app_id"]
+ pushkey = content["pushkey"]
+
+ await self.pusher_pool.remove_pusher(app_id, pushkey, user_id)
+
+ return 200, {}
+
+
+def register_servlets(hs, http_server):
+ ReplicationRemovePusherRestServlet(hs).register(http_server)
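
Worker-side callers use the standard replication-client pattern, as the PusherPool change elsewhere in this patch does:

    client = ReplicationRemovePusherRestServlet.make_client(hs)
    await client(app_id=app_id, pushkey=pushkey, user_id=user_id)
    # -> POST /_synapse/replication/remove_pusher/<user_id>
    #    with app_id and pushkey in the JSON body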
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 0a9da79c32..bb447f75b4 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -325,31 +325,6 @@ class FederationAckCommand(Command):
return "%s %s" % (self.instance_name, self.token)
-class RemovePusherCommand(Command):
- """Sent by the client to request the master remove the given pusher.
-
- Format::
-
- REMOVE_PUSHER <app_id> <push_key> <user_id>
- """
-
- NAME = "REMOVE_PUSHER"
-
- def __init__(self, app_id, push_key, user_id):
- self.user_id = user_id
- self.app_id = app_id
- self.push_key = push_key
-
- @classmethod
- def from_line(cls, line):
- app_id, push_key, user_id = line.split(" ", 2)
-
- return cls(app_id, push_key, user_id)
-
- def to_line(self):
- return " ".join((self.app_id, self.push_key, self.user_id))
-
-
class UserIpCommand(Command):
"""Sent periodically when a worker sees activity from a client.
@@ -416,7 +391,6 @@ _COMMANDS = (
ReplicateCommand,
UserSyncCommand,
FederationAckCommand,
- RemovePusherCommand,
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
@@ -443,7 +417,6 @@ VALID_CLIENT_COMMANDS = (
UserSyncCommand.NAME,
ClearUserSyncsCommand.NAME,
FederationAckCommand.NAME,
- RemovePusherCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index d1d00c3717..a7245da152 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -44,7 +44,6 @@ from synapse.replication.tcp.commands import (
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
- RemovePusherCommand,
ReplicateCommand,
UserIpCommand,
UserSyncCommand,
@@ -373,23 +372,6 @@ class ReplicationCommandHandler:
if self._federation_sender:
self._federation_sender.federation_ack(cmd.instance_name, cmd.token)
- def on_REMOVE_PUSHER(
- self, conn: AbstractConnection, cmd: RemovePusherCommand
- ) -> Optional[Awaitable[None]]:
- remove_pusher_counter.inc()
-
- if self._is_master:
- return self._handle_remove_pusher(cmd)
- else:
- return None
-
- async def _handle_remove_pusher(self, cmd: RemovePusherCommand):
- await self._store.delete_pusher_by_app_id_pushkey_user_id(
- app_id=cmd.app_id, pushkey=cmd.push_key, user_id=cmd.user_id
- )
-
- self._notifier.on_new_replication_data()
-
def on_USER_IP(
self, conn: AbstractConnection, cmd: UserIpCommand
) -> Optional[Awaitable[None]]:
@@ -684,11 +666,6 @@ class ReplicationCommandHandler:
UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
)
- def send_remove_pusher(self, app_id: str, push_key: str, user_id: str):
- """Poke the master to remove a pusher for a user"""
- cmd = RemovePusherCommand(app_id, push_key, user_id)
- self.send_command(cmd)
-
def send_user_ip(
self,
user_id: str,
diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html
index f4fdc40b22..00e1dcdbb8 100644
--- a/synapse/res/templates/sso_auth_account_details.html
+++ b/synapse/res/templates/sso_auth_account_details.html
@@ -145,7 +145,7 @@
<input type="submit" value="Continue" class="primary-button">
{% if user_attributes.avatar_url or user_attributes.display_name or user_attributes.emails %}
<section class="idp-pick-details">
- <h2><img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>Information from {{ idp.idp_name }}</h2>
+ <h2>{% if idp.idp_icon %}<img src="{{ idp.idp_icon | mxc_to_http(24, 24) }}"/>{% endif %}Information from {{ idp.idp_name }}</h2>
{% if user_attributes.avatar_url %}
<label class="idp-detail idp-avatar" for="idp-avatar">
<div class="check-row">
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 998a0ef671..9c701c7348 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -35,6 +35,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
)
from synapse.rest.client.v2_alpha._base import client_patterns
+from synapse.storage.databases.main.media_repository import MediaSortOrder
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
@@ -832,8 +833,33 @@ class UserMediaRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
+ # If neither `order_by` nor `dir` is set, set the default order
+ # to newest media is on top for backward compatibility.
+ if b"order_by" not in request.args and b"dir" not in request.args:
+ order_by = MediaSortOrder.CREATED_TS.value
+ direction = "b"
+ else:
+ order_by = parse_string(
+ request,
+ "order_by",
+ default=MediaSortOrder.CREATED_TS.value,
+ allowed_values=(
+ MediaSortOrder.MEDIA_ID.value,
+ MediaSortOrder.UPLOAD_NAME.value,
+ MediaSortOrder.CREATED_TS.value,
+ MediaSortOrder.LAST_ACCESS_TS.value,
+ MediaSortOrder.MEDIA_LENGTH.value,
+ MediaSortOrder.MEDIA_TYPE.value,
+ MediaSortOrder.QUARANTINED_BY.value,
+ MediaSortOrder.SAFE_FROM_QUARANTINE.value,
+ ),
+ )
+ direction = parse_string(
+ request, "dir", default="f", allowed_values=("f", "b")
+ )
+
media, total = await self.store.get_local_media_by_user_paginate(
- start, limit, user_id
+ start, limit, user_id, order_by, direction
)
ret = {"media": media, "total": total}
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index a3dee14ed4..79c1b526ee 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -56,10 +56,8 @@ class SendToDeviceRestServlet(servlet.RestServlet):
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("messages",))
- sender_user_id = requester.user.to_string()
-
await self.device_message_handler.send_device_message(
- sender_user_id, message_type, content["messages"]
+ requester, message_type, content["messages"]
)
response = (200, {}) # type: Tuple[int, dict]
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index a0162d4255..3375455c43 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -509,7 +509,7 @@ class MediaRepository:
t_height: int,
t_method: str,
t_type: str,
- url_cache: str,
+ url_cache: Optional[str],
) -> Optional[str]:
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(None, media_id, url_cache=url_cache)
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 1057e638be..b1b1c9e6ec 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -244,7 +244,7 @@ class MediaStorage:
await consumer.wait()
return local_path
- raise Exception("file could not be found")
+ raise NotFoundError()
def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d653a58be9..3ab90e9f9b 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -114,6 +114,7 @@ class ThumbnailResource(DirectServeJsonResource):
m_type,
thumbnail_infos,
media_id,
+ media_id,
url_cache=media_info["url_cache"],
server_name=None,
)
@@ -269,6 +270,7 @@ class ThumbnailResource(DirectServeJsonResource):
method,
m_type,
thumbnail_infos,
+ media_id,
media_info["filesystem_id"],
url_cache=None,
server_name=server_name,
@@ -282,6 +284,7 @@ class ThumbnailResource(DirectServeJsonResource):
desired_method: str,
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
+ media_id: str,
file_id: str,
url_cache: Optional[str] = None,
server_name: Optional[str] = None,
@@ -317,8 +320,59 @@ class ThumbnailResource(DirectServeJsonResource):
return
responder = await self.media_storage.fetch_media(file_info)
+ if responder:
+ await respond_with_responder(
+ request,
+ responder,
+ file_info.thumbnail_type,
+ file_info.thumbnail_length,
+ )
+ return
+
+ # If we can't find the thumbnail we regenerate it. This can happen
+ # if e.g. we've deleted the thumbnails but still have the original
+ # image somewhere.
+ #
+ # Since we have an entry for the thumbnail in the DB we a) know we
+ # have successfully generated the thumbnail in the past (so we
+ # don't need to worry about repeatedly failing to generate
+ # thumbnails), and b) have already calculated the appropriate
+ # width/height/method so we can just call the "generate exact"
+ # methods.
+
+ # First let's check that we do actually have the original image
+ # still. This will throw a 404 if we don't.
+ # TODO: We should refetch the thumbnails for remote media.
+ await self.media_storage.ensure_media_is_in_local_cache(
+ FileInfo(server_name, file_id, url_cache=url_cache)
+ )
+
+ if server_name:
+ await self.media_repo.generate_remote_exact_thumbnail(
+ server_name,
+ file_id=file_id,
+ media_id=media_id,
+ t_width=file_info.thumbnail_width,
+ t_height=file_info.thumbnail_height,
+ t_method=file_info.thumbnail_method,
+ t_type=file_info.thumbnail_type,
+ )
+ else:
+ await self.media_repo.generate_local_exact_thumbnail(
+ media_id=media_id,
+ t_width=file_info.thumbnail_width,
+ t_height=file_info.thumbnail_height,
+ t_method=file_info.thumbnail_method,
+ t_type=file_info.thumbnail_type,
+ url_cache=url_cache,
+ )
+
+ responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(
- request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+ request,
+ responder,
+ file_info.thumbnail_type,
+ file_info.thumbnail_length,
)
else:
logger.info("Failed to find any generated thumbnails")
diff --git a/synapse/server.py b/synapse/server.py
index 6b3892e3cd..4b9ec7f0ae 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -248,7 +248,7 @@ class HomeServer(metaclass=abc.ABCMeta):
self.start_time = None # type: Optional[int]
self._instance_id = random_string(5)
- self._instance_name = config.worker_name or "master"
+ self._instance_name = config.worker.instance_name
self.version_string = version_string
@@ -758,12 +758,6 @@ class HomeServer(metaclass=abc.ABCMeta):
reconnect=True,
)
- async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
- return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
-
def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?"
- return self.config.send_federation and (
- not self.config.worker_app
- or self.config.worker_app == "synapse.app.federation_sender"
- )
+ return self.config.send_federation
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 4646926449..f1ba529a2d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
-from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection
# python 3 does not have a maximum int value
@@ -381,7 +380,10 @@ class DatabasePool:
_TXN_ID = 0
def __init__(
- self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ self,
+ hs,
+ database_config: DatabaseConnectionConfig,
+ engine: BaseDatabaseEngine,
):
self.hs = hs
self._clock = hs.get_clock()
@@ -420,16 +422,6 @@ class DatabasePool:
self._check_safe_to_upsert,
)
- # We define this sequence here so that it can be referenced from both
- # the DataStore and PersistEventStore.
- def get_chain_id_txn(txn):
- txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return txn.fetchone()[0]
-
- self.event_chain_id_gen = build_sequence_generator(
- engine, get_chain_id_txn, "event_auth_chain_id"
- )
-
def is_running(self) -> bool:
"""Is the database pool currently running"""
return self._db_pool.running
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index e84f8b42f7..379c78bb83 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -79,7 +79,7 @@ class Databases:
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
if hs.get_instance_name() in hs.config.worker.writers.events:
- persist_events = PersistEventsStore(hs, database, main)
+ persist_events = PersistEventsStore(hs, database, main, db_conn)
if "state" in database_config.databases:
logger.info(
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 287606cb4f..cd1ceac50e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -42,7 +42,9 @@ from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
+from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -90,7 +92,11 @@ class PersistEventsStore:
"""
def __init__(
- self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore"
+ self,
+ hs: "HomeServer",
+ db: DatabasePool,
+ main_data_store: "DataStore",
+ db_conn: Connection,
):
self.hs = hs
self.db_pool = db
@@ -474,6 +480,7 @@ class PersistEventsStore:
self._add_chain_cover_index(
txn,
self.db_pool,
+ self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
@@ -484,6 +491,7 @@ class PersistEventsStore:
cls,
txn,
db_pool: DatabasePool,
+ event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
@@ -630,6 +638,7 @@ class PersistEventsStore:
new_chain_tuples = cls._allocate_chain_ids(
txn,
db_pool,
+ event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
@@ -768,6 +777,7 @@ class PersistEventsStore:
def _allocate_chain_ids(
txn,
db_pool: DatabasePool,
+ event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
@@ -880,7 +890,7 @@ class PersistEventsStore:
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
# Generate new chain IDs for all unallocated chain IDs.
- newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+ newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
txn, len(unallocated_chain_ids)
)
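
A note on the hunks above: `_add_chain_cover_index` and `_allocate_chain_ids` now take the chain-ID sequence generator as an explicit `SequenceGenerator` argument instead of reaching through `db_pool`, which lets the background updater (next file) reuse the same classmethods with its own store's generator. A minimal sketch of that injection pattern, using illustrative stand-ins rather than Synapse's real classes:

    from typing import Dict, List

    class StubSequenceGenerator:
        """Stand-in for SequenceGenerator; hands out increasing IDs."""

        def __init__(self) -> None:
            self._current = 0

        def get_next_mult_txn(self, txn, n: int) -> List[int]:
            ids = list(range(self._current + 1, self._current + n + 1))
            self._current += n
            return ids

    def allocate_chain_ids(txn, id_gen, unallocated: List[str]) -> Dict[str, int]:
        # The generator is injected by the caller, so the events persister and
        # the background updater both drive allocation via the shared sequence.
        new_ids = id_gen.get_next_mult_txn(txn, len(unallocated))
        return dict(zip(unallocated, new_ids))

    gen = StubSequenceGenerator()
    print(allocate_chain_ids(None, gen, ["$event_a", "$event_b"]))
    # -> {'$event_a': 1, '$event_b': 2}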
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 89274e75f7..c1626ccf28 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
+ self.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c8850a4707..edbe42f2bf 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection, JsonDict, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_list = []
self._event_fetch_ongoing = 0
+ # We define this sequence here so that it can be referenced from both
+ # the DataStore and PersistEventStore.
+ def get_chain_id_txn(txn):
+ txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+ return txn.fetchone()[0]
+
+ self.event_chain_id_gen = build_sequence_generator(
+ db_conn,
+ database.engine,
+ get_chain_id_txn,
+ "event_auth_chain_id",
+ table="event_auth_chains",
+ id_column="chain_id",
+ )
+
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
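
For context on the move: the generator now lives on `EventsWorkerStore`, where a live `db_conn` is available for the construction-time consistency check, and `PersistEventsStore` reaches it through its reference to the main store (the `self.store.event_chain_id_gen` argument in the `_add_chain_cover_index` call earlier in this diff). A toy sketch of that ownership arrangement, with hypothetical class names:

    class MainStoreSketch:
        def __init__(self) -> None:
            # Built once here, where a connection is available at startup.
            self.event_chain_id_gen = object()

    class PersistStoreSketch:
        def __init__(self, main_store: MainStoreSketch) -> None:
            self.store = main_store

    main = MainStoreSketch()
    persist = PersistStoreSketch(main)
    # Both stores observe the same generator instance:
    assert persist.store.event_chain_id_gen is main.event_chain_id_gen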
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index a0313c3ccf..274f8de595 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
@@ -23,6 +24,22 @@ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
)
+class MediaSortOrder(Enum):
+ """
+ Enum to define the sorting method used when returning media with
+ get_local_media_by_user_paginate
+ """
+
+ MEDIA_ID = "media_id"
+ UPLOAD_NAME = "upload_name"
+ CREATED_TS = "created_ts"
+ LAST_ACCESS_TS = "last_access_ts"
+ MEDIA_LENGTH = "media_length"
+ MEDIA_TYPE = "media_type"
+ QUARANTINED_BY = "quarantined_by"
+ SAFE_FROM_QUARANTINE = "safe_from_quarantine"
+
+
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
@@ -118,7 +135,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
async def get_local_media_by_user_paginate(
- self, start: int, limit: int, user_id: str
+ self,
+ start: int,
+ limit: int,
+ user_id: str,
+ order_by: str = MediaSortOrder.CREATED_TS.value,
+ direction: str = "f",
) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media
which a user has uploaded
@@ -127,6 +149,8 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
start: offset in the list
limit: maximum number of media_ids to retrieve
user_id: fully-qualified user id
+ order_by: the sort order of the returned list
+ direction: "b" to sort descending, otherwise ascending
Returns:
A paginated list of all metadata of user's media,
plus the total count of all the user's media
@@ -134,6 +158,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def get_local_media_by_user_paginate_txn(txn):
+ # Set ordering. Round-tripping through MediaSortOrder raises ValueError
+ # for an unknown column, so arbitrary input never reaches the SQL below.
+ order_by_column = MediaSortOrder(order_by).value
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
args = [user_id]
sql = """
SELECT COUNT(*) as total_media
@@ -155,9 +187,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"safe_from_quarantine"
FROM local_media_repository
WHERE user_id = ?
- ORDER BY created_ts DESC, media_id DESC
+ ORDER BY {order_by_column} {order}, media_id ASC
LIMIT ? OFFSET ?
- """
+ """.format(
+ order_by_column=order_by_column,
+ order=order,
+ )
args += [limit, start]
txn.execute(sql, args)
@@ -344,16 +379,16 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- await self.db_pool.simple_insert(
- "local_media_repository_thumbnails",
- {
+ await self.db_pool.simple_upsert(
+ table="local_media_repository_thumbnails",
+ keyvalues={
"media_id": media_id,
"thumbnail_width": thumbnail_width,
"thumbnail_height": thumbnail_height,
"thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type,
- "thumbnail_length": thumbnail_length,
},
+ values={"thumbnail_length": thumbnail_length},
desc="store_local_thumbnail",
)
@@ -498,18 +533,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- await self.db_pool.simple_insert(
- "remote_media_cache_thumbnails",
- {
+ await self.db_pool.simple_upsert(
+ table="remote_media_cache_thumbnails",
+ keyvalues={
"media_origin": origin,
"media_id": media_id,
"thumbnail_width": thumbnail_width,
"thumbnail_height": thumbnail_height,
"thumbnail_method": thumbnail_method,
"thumbnail_type": thumbnail_type,
- "thumbnail_length": thumbnail_length,
- "filesystem_id": filesystem_id,
},
+ values={"thumbnail_length": thumbnail_length},
+ insertion_values={"filesystem_id": filesystem_id},
desc="store_remote_media_thumbnail",
)
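
Two idioms worth noting in the media hunks: the ORDER BY column is round-tripped through `MediaSortOrder(...)`, so an unexpected value raises `ValueError` before any string formatting touches the SQL, and the thumbnail writers switch from `simple_insert` to `simple_upsert` so regenerating an existing thumbnail updates the row instead of hitting the unique constraint. A reduced sketch of the enum-validation idiom (illustrative table name, not the full query):

    from enum import Enum

    class SortOrder(Enum):
        MEDIA_ID = "media_id"
        CREATED_TS = "created_ts"

    def build_query(order_by: str, direction: str) -> str:
        # Raises ValueError for anything not in the enum, so only known
        # column names are ever interpolated into the SQL below.
        column = SortOrder(order_by).value
        order = "DESC" if direction == "b" else "ASC"
        return (
            "SELECT media_id FROM local_media_repository "
            "ORDER BY {} {}, media_id ASC".format(column, order)
        )

    print(build_query("created_ts", "b"))
    # build_query("1; DROP TABLE media", "f") raises before any SQL is built.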
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d5b5507815..61a7556e56 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -23,7 +23,7 @@ import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor
@@ -70,7 +70,12 @@ class TokenLookupResult:
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# call `find_max_generated_user_id_localpart` each time, which is
# expensive if there are many entries.
self._user_id_seq = build_sequence_generator(
+ db_conn,
database.engine,
find_max_generated_user_id_localpart,
"user_id_seq",
+ table=None,
+ id_column=None,
)
self._account_validity = hs.config.account_validity
@@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
diff --git a/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
new file mode 100644
index 0000000000..2442eea6bc
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/08delete_stale_pushers.sql
@@ -0,0 +1,19 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- Delete all pushers associated with deleted devices. This is to clear up after
+-- a bug where they weren't correctly deleted when using workers.
+DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens);
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 63f88eac51..1026f321e5 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -497,8 +497,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
async def add_users_in_public_rooms(
self, room_id: str, user_ids: Iterable[str]
) -> None:
- """Insert entries into the users_who_share_private_rooms table. The first
- user should be a local user.
+ """Insert entries into the users_in_public_rooms table.
Args:
room_id
@@ -556,6 +555,11 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
+ self._prefer_local_users_in_search = (
+ hs.config.user_directory_search_prefer_local_users
+ )
+ self._server_name = hs.config.server_name
+
async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
self.db_pool.simple_delete_txn(
@@ -665,7 +669,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- @cached()
async def get_shared_rooms_for_users(
self, user_id: str, other_user_id: str
) -> Set[str]:
@@ -754,9 +757,24 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
)
"""
+ # We allow manipulating the ranking algorithm by injecting statements
+ # based on config options.
+ additional_ordering_statements = []
+ ordering_arguments = ()
+
if isinstance(self.database_engine, PostgresEngine):
full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
+ # If enabled, this config option will rank local users higher than those on
+ # remote instances.
+ if self._prefer_local_users_in_search:
+ # This statement checks whether a given user's user ID contains a server name
+ # that matches the local server
+ statement = "* (CASE WHEN user_id LIKE ? THEN 2.0 ELSE 1.0 END)"
+ additional_ordering_statements.append(statement)
+
+ ordering_arguments += ("%:" + self._server_name,)
+
# We order by rank and then if they have profile info
# The ranking algorithm is hand tweaked for "best" results. Broadly
# the idea is we give a higher weight to exact matches.
@@ -767,7 +785,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
WHERE
- %s
+ %(where_clause)s
AND vector @@ to_tsquery('simple', ?)
ORDER BY
(CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
@@ -787,33 +805,54 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
8
)
)
+ %(order_case_statements)s
DESC,
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
- """ % (
- where_clause,
+ """ % {
+ "where_clause": where_clause,
+ "order_case_statements": " ".join(additional_ordering_statements),
+ }
+ args = (
+ join_args
+ + (full_query, exact_query, prefix_query)
+ + ordering_arguments
+ + (limit + 1,)
)
- args = join_args + (full_query, exact_query, prefix_query, limit + 1)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
+ # If enabled, this config option will rank local users higher than those on
+ # remote instances.
+ if self._prefer_local_users_in_search:
+ # This statement checks whether a given user's user ID contains a server name
+ # that matches the local server
+ #
+ # Note that we need to include a comma at the end for valid SQL
+ statement = "user_id LIKE ? DESC,"
+ additional_ordering_statements.append(statement)
+
+ ordering_arguments += ("%:" + self._server_name,)
+
sql = """
SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search as t
INNER JOIN user_directory AS d USING (user_id)
WHERE
- %s
+ %(where_clause)s
AND value MATCH ?
ORDER BY
rank(matchinfo(user_directory_search)) DESC,
+ %(order_statements)s
display_name IS NULL,
avatar_url IS NULL
LIMIT ?
- """ % (
- where_clause,
- )
- args = join_args + (search_query, limit + 1)
+ """ % {
+ "where_clause": where_clause,
+ "order_statements": " ".join(additional_ordering_statements),
+ }
+ args = join_args + (search_query,) + ordering_arguments + (limit + 1,)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
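
The crux of the user-directory change is keeping the injected ORDER BY fragments and their `?` placeholders in lock-step: each fragment appended to `additional_ordering_statements` appends exactly one value to `ordering_arguments`, and both are spliced into the query and argument tuple at the same position. A reduced sketch of that bookkeeping (stand-in query, not the real ranking SQL):

    def build_search(search_term, prefer_local, server_name):
        ordering_statements = []
        ordering_arguments = ()

        if prefer_local:
            # One fragment, one matching parameter, appended together.
            ordering_statements.append("user_id LIKE ? DESC,")
            ordering_arguments += ("%:" + server_name,)

        sql = """
            SELECT user_id FROM user_directory_search
            WHERE value MATCH ?
            ORDER BY
                %(order_statements)s
                display_name IS NULL
            LIMIT ?
        """ % {
            "order_statements": " ".join(ordering_statements)
        }

        # The argument order mirrors the placeholder order in the SQL.
        args = (search_term,) + ordering_arguments + (10,)
        return sql, args

    sql, args = build_search("alice", True, "example.com")
    print(args)  # ('alice', '%:example.com', 10)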
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index b16b9905d8..e2240703a7 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return txn.fetchone()[0]
self._state_group_seq_gen = build_sequence_generator(
- self.database_engine, get_max_state_group_txn, "state_group_id_seq"
- )
- self._state_group_seq_gen.check_consistency(
- db_conn, table="state_groups", id_column="id"
+ db_conn,
+ self.database_engine,
+ get_max_state_group_txn,
+ "state_group_id_seq",
+ table="state_groups",
+ id_column="id",
)
@cached(max_entries=10000, iterable=True)
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 3ea637b281..36a67e7019 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -251,9 +251,14 @@ class LocalSequenceGenerator(SequenceGenerator):
def build_sequence_generator(
+ db_conn: "LoggingDatabaseConnection",
database_engine: BaseDatabaseEngine,
get_first_callback: GetFirstCallbackType,
sequence_name: str,
+ table: Optional[str],
+ id_column: Optional[str],
+ stream_name: Optional[str] = None,
+ positive: bool = True,
) -> SequenceGenerator:
"""Get the best impl of SequenceGenerator available
@@ -265,8 +270,23 @@ def build_sequence_generator(
get_first_callback: a callback which gets the next sequence ID. Used if
we're on sqlite.
sequence_name: the name of a postgres sequence to use.
+ table, id_column, stream_name, positive: if `table` is set then
+ `check_consistency` is called on the created sequence with these
+ arguments; see the docstring of `check_consistency` for details.
"""
if isinstance(database_engine, PostgresEngine):
- return PostgresSequenceGenerator(sequence_name)
+ seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator
else:
- return LocalSequenceGenerator(get_first_callback)
+ seq = LocalSequenceGenerator(get_first_callback)
+
+ if table:
+ assert id_column
+ seq.check_consistency(
+ db_conn=db_conn,
+ table=table,
+ id_column=id_column,
+ stream_name=stream_name,
+ positive=positive,
+ )
+
+ return seq
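
With this refactor every call site constructs and validates its sequence in a single step; passing `table=None` (as the registration store does for `user_id_seq`) skips the check. The check exists to catch a Postgres sequence that has fallen behind its backing table, which would otherwise hand out duplicate IDs. A self-contained illustration of the invariant being enforced (not Synapse's implementation):

    def check_consistency_sketch(sequence_value, table_max_id, positive=True):
        # The sequence must be at or beyond the largest ID already in the
        # table, or the next get_next() could hand out a duplicate ID.
        if positive and sequence_value < table_max_id:
            raise RuntimeError(
                "Sequence at %d is behind table max %d and needs resetting"
                % (sequence_value, table_max_id)
            )

    check_consistency_sketch(10, 7)  # fine: the sequence is ahead
    # check_consistency_sketch(5, 7) would raise, flagging duplicate-ID risk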
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 32228f42ee..53f85195a7 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Set, TypeVar
from twisted.internet import defer
@@ -40,6 +40,7 @@ class ResponseCache(Generic[T]):
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
+ self.pending_conditionals = {} # type: Dict[T, Set[Callable[[Any], bool]]]
self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.0
@@ -101,7 +102,11 @@ class ResponseCache(Generic[T]):
self.pending_result_cache[key] = result
def remove(r):
- if self.timeout_sec:
+ should_cache = all(
+ func(r) for func in self.pending_conditionals.pop(key, [])
+ )
+
+ if self.timeout_sec and should_cache:
self.clock.call_later(
self.timeout_sec, self.pending_result_cache.pop, key, None
)
@@ -112,6 +117,31 @@ class ResponseCache(Generic[T]):
result.addBoth(remove)
return result.observe()
+ def add_conditional(self, key: T, conditional: Callable[[Any], bool]):
+ """Register a conditional for `key`; remove() checks them all on completion."""
+ self.pending_conditionals.setdefault(key, set()).add(conditional)
+
+ def wrap_conditional(
+ self,
+ key: T,
+ should_cache: Callable[[Any], bool],
+ callback: "Callable[..., Any]",
+ *args: Any,
+ **kwargs: Any
+ ) -> defer.Deferred:
+ """The same as wrap(), but adds a conditional to the final execution.
+
+ When the final execution completes, *all* conditionals need to return True for it to properly cache,
+ else it'll not be cached in a timed fashion.
+ """
+
+ # See if there's already a result for this key that hasn't yet completed.
+ # Python's single-threaded execution model means that registering the
+ # conditional here, before calling wrap(), cannot race with the lookup.
+ result = self.get(key)
+ if not result or (isinstance(result, defer.Deferred) and not result.called):
+ self.add_conditional(key, should_cache)
+
+ return self.wrap(key, callback, *args, **kwargs)
+
def wrap(
self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
) -> defer.Deferred:
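
A stand-alone sketch of the conditional-caching idea introduced above, with the twisted machinery stripped out (hypothetical key and result shapes):

    from typing import Any, Callable, Dict, Set

    pending_conditionals = {}  # type: Dict[str, Set[Callable[[Any], bool]]]

    def add_conditional(key, conditional):
        pending_conditionals.setdefault(key, set()).add(conditional)

    def should_cache(key, result):
        # Mirrors remove() above: every conditional registered while the
        # request was pending must approve the result for it to stay cached.
        return all(func(result) for func in pending_conditionals.pop(key, set()))

    add_conditional("profile", lambda r: not r.get("partial", False))
    print(should_cache("profile", {"partial": True}))   # False: evict now
    print(should_cache("profile", {"partial": False}))  # True: none left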