diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 5853a54c95..1c9456e583 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,13 +17,15 @@
"""Contains exceptions and error codes."""
import logging
-from typing import Dict
+from typing import Dict, List
from six import iteritems
from six.moves import http_client
from canonicaljson import json
+from twisted.web import http
+
logger = logging.getLogger(__name__)
@@ -80,6 +82,29 @@ class CodeMessageException(RuntimeError):
self.msg = msg
+class RedirectException(CodeMessageException):
+ """A pseudo-error indicating that we want to redirect the client to a different
+ location
+
+ Attributes:
+ cookies: a list of Set-Cookie values to add to the response. For example:
+ b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT"
+ """
+
+ def __init__(self, location: bytes, http_code: int = http.FOUND):
+ """
+
+ Args:
+ location: the URI to redirect to
+ http_code: the HTTP response code
+ """
+ msg = "Redirect to %s" % (location.decode("utf-8"),)
+ super().__init__(code=http_code, msg=msg)
+ self.location = location
+
+ self.cookies = [] # type: List[bytes]
+
+
class SynapseError(CodeMessageException):
"""A base exception type for matrix errors which have an errcode and error
message (as well as an HTTP status code).
@@ -158,12 +183,6 @@ class UserDeactivatedError(SynapseError):
)
-class RegistrationError(SynapseError):
- """An error raised when a registration event fails."""
-
- pass
-
-
class FederationDeniedError(SynapseError):
"""An error raised when the server tries to federate with a server which
is not on its federation whitelist.
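For illustration, a minimal sketch of how a handler might raise the new RedirectException (the handler class, URL and cookie value below are hypothetical, not part of this change):

```python
# Hypothetical usage of RedirectException: raising it aborts request
# processing and results in an HTTP 302 with a Location header (see the
# _return_html_error changes in synapse/http/server.py below).
from synapse.api.errors import RedirectException


class ExampleSsoHandler:
    def handle_request(self, request):
        e = RedirectException(b"https://idp.example.com/login")
        # Optionally attach Set-Cookie values to the redirect response.
        e.cookies.append(b"session=abc123; Path=/; HttpOnly")
        raise e
```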
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e5b44a5eed..c2a334a2b0 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -31,7 +31,7 @@ from prometheus_client import Gauge
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
-from twisted.web.resource import EncodingResourceWrapper, NoResource
+from twisted.web.resource import EncodingResourceWrapper, IResource, NoResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
@@ -109,7 +109,16 @@ class SynapseHomeServer(HomeServer):
for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule)
handler = handler_cls(config, module_api)
- resources[path] = AdditionalResource(self, handler.handle_request)
+ if IResource.providedBy(handler):
+ resource = handler
+ elif hasattr(handler, "handle_request"):
+ resource = AdditionalResource(self, handler.handle_request)
+ else:
+ raise ConfigError(
+ "additional_resource %s does not implement a known interface"
+ % (resmodule["module"],)
+ )
+ resources[path] = resource
# try to find something useful to redirect '/' to
if WEB_CLIENT_PREFIX in resources:
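A hedged sketch of a module that can now be listed under `additional_resources` by providing IResource directly (the class and behaviour are made up):

```python
# Hypothetical module resource: because twisted.web.resource.Resource
# provides IResource, the homeserver now mounts it as-is instead of
# wrapping it in AdditionalResource.
from twisted.web.resource import Resource


class ExampleResource(Resource):
    isLeaf = True

    def __init__(self, config, module_api):
        super().__init__()
        self._config = config

    def render_GET(self, request):
        request.setHeader(b"Content-Type", b"text/plain; charset=utf-8")
        return b"hello from a module resource"
```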
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d7ce333822..8eddb3bf2c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Dict
import six
from six import iteritems
@@ -22,6 +23,7 @@ from six import iteritems
from canonicaljson import json
from prometheus_client import Counter
+from twisted.internet import defer
from twisted.internet.abstract import isIPAddress
from twisted.python import failure
@@ -41,7 +43,11 @@ from synapse.federation.federation_base import FederationBase, event_from_pdu_js
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
-from synapse.logging.context import nested_logging_context
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ nested_logging_context,
+ run_in_background,
+)
from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
@@ -49,7 +55,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.types import get_domain_from_id
-from synapse.util import glob_to_regex
+from synapse.util import glob_to_regex, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -160,6 +166,43 @@ class FederationServer(FederationBase):
)
return 400, response
+ # We process PDUs and EDUs in parallel. This is important as we don't
+ # want to block things like to-device messages from reaching clients
+ # behind the potentially expensive handling of PDUs.
+ pdu_results, _ = await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self._handle_pdus_in_txn, origin, transaction, request_time
+ ),
+ run_in_background(self._handle_edus_in_txn, origin, transaction),
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ response = {"pdus": pdu_results}
+
+ logger.debug("Returning: %s", str(response))
+
+ await self.transaction_actions.set_response(origin, transaction, 200, response)
+ return 200, response
+
+ async def _handle_pdus_in_txn(
+ self, origin: str, transaction: Transaction, request_time: int
+ ) -> Dict[str, dict]:
+ """Process the PDUs in a received transaction.
+
+ Args:
+ origin: the server making the request
+ transaction: incoming transaction
+ request_time: timestamp that the HTTP request arrived at
+
+ Returns:
+ A map from event ID of a processed PDU to any errors we should
+ report back to the sending server.
+ """
+
received_pdus_counter.inc(len(transaction.pdus))
origin_host, _ = parse_server_name(origin)
@@ -250,20 +293,23 @@ class FederationServer(FederationBase):
process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
- if hasattr(transaction, "edus"):
- for edu in (Edu(**x) for x in transaction.edus):
- await self.received_edu(origin, edu.edu_type, edu.content)
+ return pdu_results
- response = {"pdus": pdu_results}
+ async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
+ """Process the EDUs in a received transaction.
+ """
- logger.debug("Returning: %s", str(response))
+ async def _process_edu(edu_dict):
+ received_edus_counter.inc()
- await self.transaction_actions.set_response(origin, transaction, 200, response)
- return 200, response
+ edu = Edu(**edu_dict)
+ await self.registry.on_edu(edu.edu_type, origin, edu.content)
- async def received_edu(self, origin, edu_type, content):
- received_edus_counter.inc()
- await self.registry.on_edu(edu_type, origin, content)
+ await concurrently_execute(
+ _process_edu,
+ getattr(transaction, "edus", []),
+ TRANSACTION_CONCURRENCY_LIMIT,
+ )
async def on_context_state_request(self, origin, room_id, event_id):
origin_host, _ = parse_server_name(origin)
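A standalone sketch of the concurrently_execute pattern now used for EDUs (the item handling and limit value are illustrative):

```python
# Sketch of the concurrently_execute helper's contract: run the given
# coroutine over every item of an iterable, with at most `limit`
# invocations in flight at any one time.
from synapse.util.async_helpers import concurrently_execute


async def process_edus(edu_dicts, limit=10):
    async def _process_edu(edu_dict):
        # handle a single EDU dict here
        print(edu_dict)

    await concurrently_execute(_process_edu, edu_dicts, limit)
```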
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 76d18a8ba8..a9407553b4 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -134,7 +134,7 @@ class AdminHandler(BaseHandler):
The returned value is that returned by `writer.finished()`.
"""
# Get all rooms the user is in or has been in
- rooms = await self.store.get_rooms_for_user_where_membership_is(
+ rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user_id,
membership_list=(
Membership.JOIN,
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 4426967f88..2afb390a92 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -140,7 +140,7 @@ class DeactivateAccountHandler(BaseHandler):
user_id (str): The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
- pending_invites = await self.store.get_invited_rooms_for_user(user_id)
+ pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
for room in pending_invites:
try:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 44ec3e66ae..2e6755f19c 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -101,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
if include_archived:
memberships.append(Membership.LEAVE)
- room_list = await self.store.get_rooms_for_user_where_membership_is(
+ room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 8a7d965feb..7ffc194f0c 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -20,13 +20,7 @@ from twisted.internet import defer
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, LoginType
-from synapse.api.errors import (
- AuthError,
- Codes,
- ConsentNotGivenError,
- RegistrationError,
- SynapseError,
-)
+from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
@@ -165,7 +159,7 @@ class RegistrationHandler(BaseHandler):
Returns:
Deferred[str]: user_id
Raises:
- RegistrationError if there was a problem registering.
+ SynapseError if there was a problem registering.
"""
yield self.check_registration_ratelimit(address)
@@ -174,7 +168,7 @@ class RegistrationHandler(BaseHandler):
if password:
password_hash = yield self._auth_handler.hash(password)
- if localpart:
+ if localpart is not None:
yield self.check_username(localpart, guest_access_token=guest_access_token)
was_guest = guest_access_token is not None
@@ -182,7 +176,7 @@ class RegistrationHandler(BaseHandler):
if not was_guest:
try:
int(localpart)
- raise RegistrationError(
+ raise SynapseError(
400, "Numeric user IDs are reserved for guest users."
)
except ValueError:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 03bb52ccfb..15e8aa5249 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -690,7 +690,7 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _get_inviter(self, user_id, room_id):
- invite = yield self.store.get_invite_for_user_in_room(
+ invite = yield self.store.get_invite_for_local_user_in_room(
user_id=user_id, room_id=room_id
)
if invite:
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 0082f85c26..107f97032b 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,6 +24,7 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.http.servlet import parse_string
+from synapse.module_api import ModuleApi
from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import (
UserID,
@@ -59,7 +60,8 @@ class SamlHandler:
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
- hs.config.saml2_user_mapping_provider_config
+ hs.config.saml2_user_mapping_provider_config,
+ ModuleApi(hs, hs.get_auth_handler()),
)
# identifier for the external_ids table
@@ -112,10 +114,10 @@ class SamlHandler:
# the dict.
self.expire_sessions()
- user_id = await self._map_saml_response_to_user(resp_bytes)
+ user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
- async def _map_saml_response_to_user(self, resp_bytes):
+ async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
@@ -183,7 +185,7 @@ class SamlHandler:
# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
- saml2_auth, i
+ saml2_auth, i, client_redirect_url=client_redirect_url,
)
logger.debug(
@@ -216,6 +218,8 @@ class SamlHandler:
500, "Unable to generate a Matrix ID from the SAML response"
)
+ logger.info("Mapped SAML user to local part %s", localpart)
+
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=displayname
)
@@ -265,17 +269,21 @@ class SamlConfig(object):
class DefaultSamlMappingProvider(object):
__version__ = "0.0.1"
- def __init__(self, parsed_config: SamlConfig):
+ def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
"""The default SAML user mapping provider
Args:
parsed_config: Module configuration
+ module_api: module api proxy
"""
self._mxid_source_attribute = parsed_config.mxid_source_attribute
self._mxid_mapper = parsed_config.mxid_mapper
def saml_response_to_user_attributes(
- self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
+ self,
+ saml_response: saml2.response.AuthnResponse,
+ failures: int,
+ client_redirect_url: str,
) -> dict:
"""Maps some text from a SAML response to attributes of a new user
@@ -285,6 +293,8 @@ class DefaultSamlMappingProvider(object):
failures: How many times a call to this function with this
saml_response has resulted in a failure
+ client_redirect_url: where the client wants to redirect to
+
Returns:
dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid
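A hedged sketch of a third-party mapping provider against the updated interface (the class name and SAML attribute are made up; a real provider would also implement the other provider hooks, such as parse_config):

```python
# Hypothetical custom SAML mapping provider matching the new interface:
# the constructor now receives a ModuleApi, and
# saml_response_to_user_attributes takes a client_redirect_url argument.
import saml2.response

from synapse.module_api import ModuleApi


class ExampleSamlMappingProvider:
    __version__ = "0.0.1"

    def __init__(self, parsed_config, module_api: ModuleApi):
        self._module_api = module_api

    def saml_response_to_user_attributes(
        self,
        saml_response: saml2.response.AuthnResponse,
        failures: int,
        client_redirect_url: str,
    ) -> dict:
        # Derive the localpart from a SAML attribute, appending the failure
        # count on retries so the registration loop can find a free mxid.
        uid = saml_response.ava["uid"][0]
        suffix = str(failures) if failures else ""
        return {"mxid_localpart": uid + suffix, "displayname": None}
```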
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ef750d1497..110097eab9 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -179,7 +179,7 @@ class SearchHandler(BaseHandler):
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too
- rooms = yield self.store.get_rooms_for_user_where_membership_is(
+ rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 2d3b8ba73c..cd95f85e3f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1662,7 +1662,7 @@ class SyncHandler(object):
Membership.BAN,
)
- room_list = await self.store.get_rooms_for_user_where_membership_is(
+ room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id, membership_list=membership_list
)
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 943d12c907..04bc2385a2 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -14,8 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import cgi
import collections
+import html
import http.client
import logging
import types
@@ -36,6 +36,7 @@ import synapse.metrics
from synapse.api.errors import (
CodeMessageException,
Codes,
+ RedirectException,
SynapseError,
UnrecognizedRequestError,
)
@@ -153,14 +154,18 @@ def _return_html_error(f, request):
Args:
f (twisted.python.failure.Failure):
- request (twisted.web.iweb.IRequest):
+ request (twisted.web.server.Request):
"""
if f.check(CodeMessageException):
cme = f.value
code = cme.code
msg = cme.msg
- if isinstance(cme, SynapseError):
+ if isinstance(cme, RedirectException):
+ logger.info("%s redirect to %s", request, cme.location)
+ request.setHeader(b"location", cme.location)
+ request.cookies.extend(cme.cookies)
+ elif isinstance(cme, SynapseError):
logger.info("%s SynapseError: %s - %s", request, code, msg)
else:
logger.error(
@@ -178,7 +183,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
- body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8")
+ body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % (len(body),))
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 9f2d035fa0..911251c0bc 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -88,7 +88,7 @@ class SynapseRequest(Request):
def get_redacted_uri(self):
uri = self.uri
if isinstance(uri, bytes):
- uri = self.uri.decode("ascii")
+ uri = self.uri.decode("ascii", errors="replace")
return redact_uri(uri)
def get_method(self):
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 33b322209d..1b940842f6 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -571,6 +571,9 @@ def run_in_background(f, *args, **kwargs):
yield or await on (for instance because you want to pass it to
deferred.gatherResults()).
+ If f returns a Coroutine object, it will be wrapped into a Deferred (which will have
+ the side effect of executing the coroutine).
+
Note that if you completely discard the result, you should make sure that
`f` doesn't raise any deferred exceptions, otherwise a scary-looking
CRITICAL error about an unhandled error will be logged without much
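A small sketch of the documented coroutine behaviour (the worker coroutine is illustrative):

```python
# Illustrates the new note above: passing a coroutine function to
# run_in_background yields a Deferred, which can then be combined with
# defer.gatherResults and awaited via make_deferred_yieldable.
from twisted.internet import defer

from synapse.logging.context import make_deferred_yieldable, run_in_background


async def _do_work():  # hypothetical coroutine
    return 42


async def parallel_work():
    results = await make_deferred_yieldable(
        defer.gatherResults(
            [run_in_background(_do_work), run_in_background(_do_work)],
            consumeErrors=True,
        )
    )
    return results  # [42, 42]
```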
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index de5c101a58..5dae4648c0 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -21,7 +21,7 @@ from synapse.storage import Storage
@defer.inlineCallbacks
def get_badge_count(store, user_id):
- invites = yield store.get_invited_rooms_for_user(user_id)
+ invites = yield store.get_invited_rooms_for_local_user(user_id)
joins = yield store.get_rooms_for_user(user_id)
my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index c8056b0c0c..444eb7b7f4 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,6 +16,7 @@
import abc
import logging
import re
+from typing import Dict, List, Tuple
from six import raise_from
from six.moves import urllib
@@ -78,9 +79,8 @@ class ReplicationEndpoint(object):
__metaclass__ = abc.ABCMeta
- NAME = abc.abstractproperty()
- PATH_ARGS = abc.abstractproperty()
-
+ NAME = abc.abstractproperty() # type: str # type: ignore
+ PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
METHOD = "POST"
CACHE = True
RETRY_ON_TIMEOUT = True
@@ -171,7 +171,7 @@ class ReplicationEndpoint(object):
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
- headers = {}
+ headers = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = yield request_func(uri, data, headers=headers)
@@ -207,7 +207,7 @@ class ReplicationEndpoint(object):
method = self.METHOD
if self.CACHE:
- handler = self._cached_handler
+ handler = self._cached_handler # type: ignore
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index b91a528245..704282c800 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Dict
+from typing import Dict, Optional
import six
@@ -41,7 +41,7 @@ class BaseSlavedStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id"
- )
+ ) # type: Optional[SlavedIdTracker]
else:
self._cache_id_gen = None
@@ -62,7 +62,8 @@ class BaseSlavedStore(SQLBaseStore):
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
- self._cache_id_gen.advance(token)
+ if self._cache_id_gen:
+ self._cache_id_gen.advance(token)
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
room_id = row.keys[0]
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 29f35b9915..3aa6cb8b96 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -152,7 +152,7 @@ class SlavedEventStore(
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
- self.get_invited_rooms_for_user.invalidate((state_key,))
+ self.get_invited_rooms_for_local_user.invalidate((state_key,))
if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,))
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index f552e7c972..ad8f0c15a9 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -29,7 +29,7 @@ class SlavedPresenceStore(BaseSlavedStore):
self._presence_on_startup = self._get_active_presence(db_conn)
- self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
+ self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index bbcb84646c..aa7fd90e26 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,7 +16,7 @@
"""
import logging
-from typing import Dict
+from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
@@ -28,6 +28,7 @@ from synapse.replication.tcp.protocol import (
)
from .commands import (
+ Command,
FederationAckCommand,
InvalidateCacheCommand,
RemovePusherCommand,
@@ -89,15 +90,15 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
# Any pending commands to be sent once a new connection has been
# established
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiveing a SYNC with
# the given string.
# Used for tests.
- self.awaiting_syncs = {}
+ self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections.
- self.factory = None
+ self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@@ -235,4 +236,5 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
- self.factory.resetDelay()
+ if self.factory:
+ self.factory.resetDelay()
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 0ff2a7199f..cbb36b9acf 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -20,15 +20,16 @@ allowed to be sent by which side.
import logging
import platform
+from typing import Tuple, Type
if platform.python_implementation() == "PyPy":
import json
_json_encoder = json.JSONEncoder()
else:
- import simplejson as json
+ import simplejson as json # type: ignore[no-redef] # noqa: F821
- _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
+ _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
logger = logging.getLogger(__name__)
@@ -44,7 +45,7 @@ class Command(object):
The default implementation creates a command of form `<NAME> <data>`
"""
- NAME = None
+ NAME = None # type: str
def __init__(self, data):
self.data = data
@@ -386,25 +387,24 @@ class UserIpCommand(Command):
)
+_COMMANDS = (
+ ServerCommand,
+ RdataCommand,
+ PositionCommand,
+ ErrorCommand,
+ PingCommand,
+ NameCommand,
+ ReplicateCommand,
+ UserSyncCommand,
+ FederationAckCommand,
+ SyncCommand,
+ RemovePusherCommand,
+ InvalidateCacheCommand,
+ UserIpCommand,
+) # type: Tuple[Type[Command], ...]
+
# Map of command name to command type.
-COMMAND_MAP = {
- cmd.NAME: cmd
- for cmd in (
- ServerCommand,
- RdataCommand,
- PositionCommand,
- ErrorCommand,
- PingCommand,
- NameCommand,
- ReplicateCommand,
- UserSyncCommand,
- FederationAckCommand,
- SyncCommand,
- RemovePusherCommand,
- InvalidateCacheCommand,
- UserIpCommand,
- )
-}
+COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
# The commands the server is allowed to send
VALID_SERVER_COMMANDS = (
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index afaf002fe6..db0353c996 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -53,6 +53,7 @@ import fcntl
import logging
import struct
from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Set, Tuple
from six import iteritems, iterkeys
@@ -65,13 +66,11 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util import Clock
-from synapse.util.stringutils import random_string
-
-from .commands import (
+from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
+ Command,
ErrorCommand,
NameCommand,
PingCommand,
@@ -82,6 +81,10 @@ from .commands import (
SyncCommand,
UserSyncCommand,
)
+from synapse.types import Collection
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
+
from .streams import STREAMS_MAP
connection_close_counter = Counter(
@@ -124,8 +127,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
delimiter = b"\n"
- VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive
- VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send
+ # Valid commands we expect to receive
+ VALID_INBOUND_COMMANDS = [] # type: Collection[str]
+
+ # Valid commands we can send
+ VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
max_line_buffer = 10000
@@ -144,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.conn_id = random_string(5) # To dedupe in case of name clashes.
# List of pending commands to send once we've established the connection
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
self._send_ping_loop = None
- self.inbound_commands_counter = defaultdict(int)
- self.outbound_commands_counter = defaultdict(int)
+ self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
+ self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -409,14 +415,14 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
- self.replication_streams = set()
+ self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
- self.connecting_streams = set()
+ self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
- self.pending_rdata = {}
+ self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
@@ -642,11 +648,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
- self.streams_connecting = set()
+ self.streams_connecting = set() # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
- self.pending_batches = {}
+ self.pending_batches = {} # type: Dict[str, Any]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
@@ -766,7 +772,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
op = SIOCINQ
else:
op = SIOCOUTQ
- size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
+ size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
return size
return 0
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index d1e98428bc..cbfdaf5773 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,6 +17,7 @@
import logging
import random
+from typing import List
from six import itervalues
@@ -79,7 +80,7 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
- self.connections = []
+ self.connections = [] # type: List[ServerReplicationStreamProtocol]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 8512923eae..4ab0334fc1 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,10 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
from collections import namedtuple
+from typing import Any
from twisted.internet import defer
@@ -104,8 +104,9 @@ class Stream(object):
time it was called up until the point `advance_current_token` was called.
"""
- NAME = None # The name of the stream
- ROW_TYPE = None # The type of the row. Used by the default impl of parse_row.
+ NAME = None # type: str # The name of the stream
+ # The type of the row. Used by the default impl of parse_row.
+ ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit
@classmethod
@@ -231,8 +232,8 @@ class BackfillStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_backfill_token
- self.update_function = store.get_all_new_backfill_event_rows
+ self.current_token = store.get_current_backfill_token # type: ignore
+ self.update_function = store.get_all_new_backfill_event_rows # type: ignore
super(BackfillStream, self).__init__(hs)
@@ -246,8 +247,8 @@ class PresenceStream(Stream):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
- self.current_token = store.get_current_presence_token
- self.update_function = presence_handler.get_all_presence_updates
+ self.current_token = store.get_current_presence_token # type: ignore
+ self.update_function = presence_handler.get_all_presence_updates # type: ignore
super(PresenceStream, self).__init__(hs)
@@ -260,8 +261,8 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
- self.current_token = typing_handler.get_current_token
- self.update_function = typing_handler.get_all_typing_updates
+ self.current_token = typing_handler.get_current_token # type: ignore
+ self.update_function = typing_handler.get_all_typing_updates # type: ignore
super(TypingStream, self).__init__(hs)
@@ -273,8 +274,8 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_max_receipt_stream_id
- self.update_function = store.get_all_updated_receipts
+ self.current_token = store.get_max_receipt_stream_id # type: ignore
+ self.update_function = store.get_all_updated_receipts # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -310,8 +311,8 @@ class PushersStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_pushers_stream_token
- self.update_function = store.get_all_updated_pushers_rows
+ self.current_token = store.get_pushers_stream_token # type: ignore
+ self.update_function = store.get_all_updated_pushers_rows # type: ignore
super(PushersStream, self).__init__(hs)
@@ -327,8 +328,8 @@ class CachesStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_cache_stream_token
- self.update_function = store.get_all_updated_caches
+ self.current_token = store.get_cache_stream_token # type: ignore
+ self.update_function = store.get_all_updated_caches # type: ignore
super(CachesStream, self).__init__(hs)
@@ -343,8 +344,8 @@ class PublicRoomsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_public_room_stream_id
- self.update_function = store.get_all_new_public_rooms
+ self.current_token = store.get_current_public_room_stream_id # type: ignore
+ self.update_function = store.get_all_new_public_rooms # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@@ -360,8 +361,8 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_device_stream_token
- self.update_function = store.get_all_device_list_changes_for_remotes
+ self.current_token = store.get_device_stream_token # type: ignore
+ self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -376,8 +377,8 @@ class ToDeviceStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_to_device_stream_token
- self.update_function = store.get_all_new_device_messages
+ self.current_token = store.get_to_device_stream_token # type: ignore
+ self.update_function = store.get_all_new_device_messages # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -392,8 +393,8 @@ class TagAccountDataStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_max_account_data_stream_id
- self.update_function = store.get_all_updated_tags
+ self.current_token = store.get_max_account_data_stream_id # type: ignore
+ self.update_function = store.get_all_updated_tags # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -408,7 +409,7 @@ class AccountDataStream(Stream):
def __init__(self, hs):
self.store = hs.get_datastore()
- self.current_token = self.store.get_max_account_data_stream_id
+ self.current_token = self.store.get_max_account_data_stream_id # type: ignore
super(AccountDataStream, self).__init__(hs)
@@ -434,8 +435,8 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_group_stream_token
- self.update_function = store.get_all_groups_changes
+ self.current_token = store.get_group_stream_token # type: ignore
+ self.update_function = store.get_all_groups_changes # type: ignore
super(GroupServerStream, self).__init__(hs)
@@ -451,7 +452,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_device_stream_token
- self.update_function = store.get_all_user_signature_changes_for_remotes
+ self.current_token = store.get_device_stream_token # type: ignore
+ self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index d97669c886..0843e5aa90 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,7 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import heapq
+from typing import Tuple, Type
import attr
@@ -63,7 +65,8 @@ class BaseEventsStreamRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+ # Unique string that ids the type. Must be overridden in subclasses.
+ TypeId = None # type: str
@classmethod
def from_data(cls, data):
@@ -99,9 +102,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id = attr.ib() # str, optional
-TypeToRow = {
- Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow)
-}
+_EventRows = (
+ EventsStreamEventRow,
+ EventsStreamCurrentStateRow,
+) # type: Tuple[Type[BaseEventsStreamRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _EventRows}
class EventsStream(Stream):
@@ -112,7 +118,7 @@ class EventsStream(Stream):
def __init__(self, hs):
self._store = hs.get_datastore()
- self.current_token = self._store.get_current_events_token
+ self.current_token = self._store.get_current_events_token # type: ignore
super(EventsStream, self).__init__(hs)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index dc2484109d..615f3dc9ac 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -37,7 +37,7 @@ class FederationStream(Stream):
def __init__(self, hs):
federation_sender = hs.get_federation_sender()
- self.current_token = federation_sender.get_current_token
- self.update_function = federation_sender.get_replication_rows
+ self.current_token = federation_sender.get_current_token # type: ignore
+ self.update_function = federation_sender.get_replication_rows # type: ignore
super(FederationStream, self).__init__(hs)
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index fa833e54cf..3a445d6eed 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -32,16 +32,24 @@ class QuarantineMediaInRoom(RestServlet):
this server.
"""
- PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
+ PATTERNS = (
+ historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+ +
+ # This path is kept around for legacy reasons
+ historical_admin_path_patterns("/quarantine_media/(?P<room_id>![^/]+)")
+ )
def __init__(self, hs):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, room_id):
+ async def on_POST(self, request, room_id: str):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
+ logging.info("Quarantining room: %s", room_id)
+
+ # Quarantine all media in this room
num_quarantined = await self.store.quarantine_media_ids_in_room(
room_id, requester.user.to_string()
)
@@ -49,6 +57,60 @@ class QuarantineMediaInRoom(RestServlet):
return 200, {"num_quarantined": num_quarantined}
+class QuarantineMediaByUser(RestServlet):
+ """Quarantines all local media by a given user so that no one can download it via
+ this server.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/user/(?P<user_id>[^/]+)/media/quarantine"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, user_id: str):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Quarantining local media by user: %s", user_id)
+
+ # Quarantine all media this user has uploaded
+ num_quarantined = await self.store.quarantine_media_ids_by_user(
+ user_id, requester.user.to_string()
+ )
+
+ return 200, {"num_quarantined": num_quarantined}
+
+
+class QuarantineMediaByID(RestServlet):
+ """Quarantines local or remote media by a given ID so that no one can download
+ it via this server.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, server_name: str, media_id: str):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Quarantining local media by ID: %s/%s", server_name, media_id)
+
+ # Quarantine this media id
+ await self.store.quarantine_media_by_id(
+ server_name, media_id, requester.user.to_string()
+ )
+
+ return 200, {}
+
+
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
@@ -94,4 +156,6 @@ def register_servlets_for_media_repo(hs, http_server):
"""
PurgeMediaCacheRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)
+ QuarantineMediaByID(hs).register(http_server)
+ QuarantineMediaByUser(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
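Assuming these servlets are mounted under the usual `/_synapse/admin/v1` prefix, the new quarantine endpoints might be exercised as follows (host, access token and IDs are placeholders):

```python
# Hypothetical calls against the new admin quarantine endpoints.
import requests

BASE = "https://homeserver.example.com/_synapse/admin/v1"
HEADERS = {"Authorization": "Bearer <admin_access_token>"}

# Quarantine all media in a room
requests.post(f"{BASE}/room/!room:example.com/media/quarantine", headers=HEADERS)

# Quarantine all local media uploaded by a user
requests.post(f"{BASE}/user/@user:example.com/media/quarantine", headers=HEADERS)

# Quarantine a single media item by server name and media ID
requests.post(f"{BASE}/media/quarantine/example.com/abcdefg", headers=HEADERS)
```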
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 2dac90578c..f7432c8d2f 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -105,7 +105,7 @@ class ServerNoticesManager(object):
assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
- rooms = yield self._store.get_rooms_for_user_where_membership_is(
+ rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
user_id, [Membership.INVITE, Membership.JOIN]
)
system_mxid = self._config.server_notices_mxid
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
index 092e803799..e1d03429ca 100644
--- a/synapse/storage/data_stores/__init__.py
+++ b/synapse/storage/data_stores/__init__.py
@@ -47,7 +47,7 @@ class DataStores(object):
with make_conn(database_config, engine) as db_conn:
logger.info("Preparing database %r...", db_name)
- engine.check_database(db_conn.cursor())
+ engine.check_database(db_conn)
prepare_database(
db_conn, engine, hs.config, data_stores=database_config.data_stores,
)
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 58f35d7f56..e9fe63037b 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -128,6 +128,7 @@ class EventsStore(
hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+ self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks
def _read_forward_extremities(self):
@@ -547,6 +548,34 @@ class EventsStore(
],
)
+ # Note: Do we really want to delete rows here (that we do not
+ # subsequently reinsert below)? While technically correct it means
+ # we have no record of the fact the user *was* a member of the
+ # room but got, say, state reset out of it.
+ if to_delete or to_insert:
+ txn.executemany(
+ "DELETE FROM local_current_membership"
+ " WHERE room_id = ? AND user_id = ?",
+ (
+ (room_id, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ if etype == EventTypes.Member and self.is_mine_id(state_key)
+ ),
+ )
+
+ if to_insert:
+ txn.executemany(
+ """INSERT INTO local_current_membership
+ (room_id, user_id, event_id, membership)
+ VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+ """,
+ [
+ (room_id, key[1], ev_id, ev_id)
+ for key, ev_id in to_insert.items()
+ if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+ ],
+ )
+
txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
room_id,
@@ -1724,6 +1753,7 @@ class EventsStore(
"local_invites",
"room_account_data",
"room_tags",
+ "local_current_membership",
):
logger.info("[purge] removing %s from %s", room_id, table)
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 8636d75030..49bab62be3 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -18,7 +18,7 @@ import collections
import logging
import re
from abc import abstractmethod
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
from six import integer_types
@@ -399,6 +399,8 @@ class RoomWorkerStore(SQLBaseStore):
the associated media
"""
+ logger.info("Quarantining media in room: %s", room_id)
+
def _quarantine_media_in_room_txn(txn):
local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
total_media_quarantined = 0
@@ -494,6 +496,118 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
+ def quarantine_media_by_id(
+ self, server_name: str, media_id: str, quarantined_by: str,
+ ):
+ """quarantines a single local or remote media id
+
+ Args:
+ server_name: The name of the server that holds this media
+ media_id: The ID of the media to be quarantined
+ quarantined_by: The user ID that initiated the quarantine request
+ """
+ logger.info("Quarantining media: %s/%s", server_name, media_id)
+ is_local = server_name == self.config.server_name
+
+ def _quarantine_media_by_id_txn(txn):
+ local_mxcs = [media_id] if is_local else []
+ remote_mxcs = [(server_name, media_id)] if not is_local else []
+
+ return self._quarantine_media_txn(
+ txn, local_mxcs, remote_mxcs, quarantined_by
+ )
+
+ return self.db.runInteraction(
+ "quarantine_media_by_user", _quarantine_media_by_id_txn
+ )
+
+ def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+ """quarantines all local media associated with a single user
+
+ Args:
+ user_id: The ID of the user to quarantine media of
+ quarantined_by: The ID of the user who made the quarantine request
+ """
+
+ def _quarantine_media_by_user_txn(txn):
+ local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
+ return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
+
+ return self.db.runInteraction(
+ "quarantine_media_by_user", _quarantine_media_by_user_txn
+ )
+
+ def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+ """Retrieves local media IDs by a given user
+
+ Args:
+ txn (cursor)
+ user_id: The ID of the user to retrieve media IDs of
+
+ Returns:
+ A list of local media IDs uploaded by the given user (excluding
+ already-quarantined items when filter_quarantined is set).
+ """
+ # Local media
+ sql = """
+ SELECT media_id
+ FROM local_media_repository
+ WHERE user_id = ?
+ """
+ if filter_quarantined:
+ sql += "AND quarantined_by IS NULL"
+ txn.execute(sql, (user_id,))
+
+ local_media_ids = [row[0] for row in txn]
+
+ # TODO: Figure out all remote media a user has referenced in a message
+
+ return local_media_ids
+
+ def _quarantine_media_txn(
+ self,
+ txn,
+ local_mxcs: List[str],
+ remote_mxcs: List[Tuple[str, str]],
+ quarantined_by: str,
+ ) -> int:
+ """Quarantine local and remote media items
+
+ Args:
+ txn (cursor)
+ local_mxcs: A list of local media IDs
+ remote_mxcs: A list of (remote server, media id) tuples identifying
+ remote media
+ quarantined_by: The ID of the user who initiated the quarantine request
+ Returns:
+ The total number of media items quarantined
+ """
+ total_media_quarantined = 0
+
+ # Update all the tables to set the quarantined_by flag
+ txn.executemany(
+ """
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """,
+ ((quarantined_by, media_id) for media_id in local_mxcs),
+ )
+
+ txn.executemany(
+ """
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
+ )
+
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
+
+ return total_media_quarantined
+
class RoomBackgroundUpdateStore(SQLBaseStore):
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 70ff5751b6..9acef7c950 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -297,19 +297,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return {row[0]: row[1] for row in txn}
@cached()
- def get_invited_rooms_for_user(self, user_id):
- """ Get all the rooms the user is invited to
+ def get_invited_rooms_for_local_user(self, user_id):
+ """ Get all the rooms the *local* user is invited to
+
Args:
user_id (str): The user ID.
Returns:
A deferred list of RoomsForUser.
"""
- return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
+ return self.get_rooms_for_local_user_where_membership_is(
+ user_id, [Membership.INVITE]
+ )
@defer.inlineCallbacks
- def get_invite_for_user_in_room(self, user_id, room_id):
- """Gets the invite for the given user and room
+ def get_invite_for_local_user_in_room(self, user_id, room_id):
+ """Gets the invite for the given *local* user and room
Args:
user_id (str)
@@ -319,15 +322,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Deferred: Resolves to either a RoomsForUser or None if no invite was
found.
"""
- invites = yield self.get_invited_rooms_for_user(user_id)
+ invites = yield self.get_invited_rooms_for_local_user(user_id)
for invite in invites:
if invite.room_id == room_id:
return invite
return None
@defer.inlineCallbacks
- def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
- """ Get all the rooms for this user where the membership for this user
+ def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
+ """ Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
Filters out forgotten rooms.
@@ -344,8 +347,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return defer.succeed(None)
rooms = yield self.db.runInteraction(
- "get_rooms_for_user_where_membership_is",
- self._get_rooms_for_user_where_membership_is_txn,
+ "get_rooms_for_local_user_where_membership_is",
+ self._get_rooms_for_local_user_where_membership_is_txn,
user_id,
membership_list,
)
@@ -354,76 +357,42 @@ class RoomMemberWorkerStore(EventsWorkerStore):
forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
return [room for room in rooms if room.room_id not in forgotten_rooms]
- def _get_rooms_for_user_where_membership_is_txn(
+ def _get_rooms_for_local_user_where_membership_is_txn(
self, txn, user_id, membership_list
):
+ # Paranoia check.
+ if not self.hs.is_mine_id(user_id):
+ raise Exception(
+ "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
+ % (user_id,),
+ )
- do_invite = Membership.INVITE in membership_list
- membership_list = [m for m in membership_list if m != Membership.INVITE]
-
- results = []
- if membership_list:
- if self._current_state_events_membership_up_to_date:
- clause, args = make_in_list_sql_clause(
- self.database_engine, "c.membership", membership_list
- )
- sql = """
- SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND state_key = ?
- AND %s
- """ % (
- clause,
- )
- else:
- clause, args = make_in_list_sql_clause(
- self.database_engine, "m.membership", membership_list
- )
- sql = """
- SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
- FROM current_state_events AS c
- INNER JOIN room_memberships AS m USING (room_id, event_id)
- INNER JOIN events AS e USING (room_id, event_id)
- WHERE
- c.type = 'm.room.member'
- AND state_key = ?
- AND %s
- """ % (
- clause,
- )
-
- txn.execute(sql, (user_id, *args))
- results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "c.membership", membership_list
+ )
- if do_invite:
- sql = (
- "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
- " FROM local_invites as i"
- " INNER JOIN events as e USING (event_id)"
- " WHERE invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
+ sql = """
+ SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+ FROM local_current_membership AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ user_id = ?
+ AND %s
+ """ % (
+ clause,
+ )
- txn.execute(sql, (user_id,))
- results.extend(
- RoomsForUser(
- room_id=r["room_id"],
- sender=r["inviter"],
- event_id=r["event_id"],
- stream_ordering=r["stream_ordering"],
- membership=Membership.INVITE,
- )
- for r in self.db.cursor_to_dict(txn)
- )
+ txn.execute(sql, (user_id, *args))
+ results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
return results
- @cachedInlineCallbacks(max_entries=500000, iterable=True)
+ @cached(max_entries=500000, iterable=True)
def get_rooms_for_user_with_stream_ordering(self, user_id):
- """Returns a set of room_ids the user is currently joined to
+ """Returns a set of room_ids the user is currently joined to.
+
+ If a remote user only returns rooms this server is currently
+ participating in.
Args:
user_id (str)
@@ -433,17 +402,49 @@ class RoomMemberWorkerStore(EventsWorkerStore):
the rooms the user is in currently, along with the stream ordering
of the most recent join for that user and room.
"""
- rooms = yield self.get_rooms_for_user_where_membership_is(
- user_id, membership_list=[Membership.JOIN]
- )
- return frozenset(
- GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
- for r in rooms
+ return self.db.runInteraction(
+ "get_rooms_for_user_with_stream_ordering",
+ self._get_rooms_for_user_with_stream_ordering_txn,
+ user_id,
)
+ def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
+ # We use `current_state_events` here and not `local_current_membership`
+ # as a) this gets called with remote users and b) this only gets called
+ # for rooms the server is participating in.
+ if self._current_state_events_membership_up_to_date:
+ sql = """
+ SELECT room_id, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND state_key = ?
+ AND c.membership = ?
+ """
+ else:
+ sql = """
+ SELECT room_id, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN room_memberships AS m USING (room_id, event_id)
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND state_key = ?
+ AND m.membership = ?
+ """
+
+ txn.execute(sql, (user_id, Membership.JOIN))
+ results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+
+ return results
+
@defer.inlineCallbacks
def get_rooms_for_user(self, user_id, on_invalidate=None):
- """Returns a set of room_ids the user is currently joined to
+ """Returns a set of room_ids the user is currently joined to.
+
+ If the user is remote, only returns rooms this server is currently
+ participating in.
"""
rooms = yield self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate
@@ -1022,7 +1023,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
event.internal_metadata.stream_ordering,
)
txn.call_after(
- self.get_invited_rooms_for_user.invalidate, (event.state_key,)
+ self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
@@ -1064,6 +1065,27 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
),
)
+ # We also update the `local_current_membership` table with
+ # latest invite info. This will usually get updated by the
+ # `current_state_events` handling, unless it's an outlier.
+ if event.internal_metadata.is_outlier():
+ # This should only happen for out of band memberships, so
+ # we add a paranoia check.
+ assert event.internal_metadata.is_out_of_band_membership()
+
+ self.db.simple_upsert_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={
+ "room_id": event.room_id,
+ "user_id": event.state_key,
+ },
+ values={
+ "event_id": event.event_id,
+ "membership": event.membership,
+ },
+ )
+
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
@@ -1075,6 +1097,15 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
+ # We also clear this entry from `local_current_membership`.
+ # Ideally we'd point to a leave event, but we don't have one, so
+ # never mind.
+ self.db.simple_delete_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ )
+
with self._stream_id_gen.get_next() as stream_ordering:
yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
new file mode 100644
index 0000000000..601c236c4a
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# We create a new table called `local_current_membership` that stores the latest
+# membership state of local users in rooms, which helps track leaves/bans/etc
+# even if the server has left the room (and so has deleted the room from
+# `current_state_events`). This will also include outstanding invites for local
+# users for rooms the server isn't in.
+#
+# If the server isn't and hasn't been in the room then it will only include
+# outstanding invites, and not e.g. pre-emptive bans of local users.
+#
+# If the server later rejoins a room `local_current_membership` can simply be
+# replaced with the new current state of the room (which results in the
+# equivalent behaviour as if the server had remained in the room).
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+ # We need to do the insert in the `run_upgrade` section as we don't have access
+ # to `config` in `run_create`.
+
+ # This upgrade may take a bit of time for large servers (e.g. one minute for
+ # matrix.org) but means we avoid a lot of bookkeeping required to do it as
+ # a background update.
+
+ # We check whether the `current_state_events.membership` column is up to
+ # date by checking if the relevant background update has finished. If it
+ # has finished we can avoid doing a join against `room_memberships`, which
+ # speeds things up.
+ cur.execute(
+ """SELECT 1 FROM background_updates
+ WHERE update_name = 'current_state_events_membership'
+ """
+ )
+ current_state_membership_up_to_date = not bool(cur.fetchone())
+
+ # Cheekily drop and recreate indices, as that is faster.
+ cur.execute("DROP INDEX local_current_membership_idx")
+ cur.execute("DROP INDEX local_current_membership_room_idx")
+
+ if current_state_membership_up_to_date:
+ sql = """
+ INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+ SELECT c.room_id, state_key AS user_id, event_id, c.membership
+ FROM current_state_events AS c
+ WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key LIKE '%' || ?
+ """
+ else:
+ # We can't rely on the membership column, so we need to join against
+ # `room_memberships`.
+ sql = """
+ INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+ SELECT c.room_id, state_key AS user_id, event_id, r.membership
+ FROM current_state_events AS c
+ INNER JOIN room_memberships AS r USING (event_id)
+ WHERE type = 'm.room.member' AND state_key LIKE '%' || ?
+ """
+ cur.execute(sql, (config.server_name,))
+
+ cur.execute(
+ "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+ )
+ cur.execute(
+ "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+ )
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ cur.execute(
+ """
+ CREATE TABLE local_current_membership (
+ room_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ membership TEXT NOT NULL
+ )"""
+ )
+
+ cur.execute(
+ "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+ )
+ cur.execute(
+ "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+ )
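To make the intent of the new table concrete, a sketch of the kind of lookup it enables even after the server has left a room (the helper mirrors the query in get_rooms_for_local_user_where_membership_is above; connection handling is elided):

```python
# Sketch: local_current_membership keeps one row per (user_id, room_id)
# for local users, so their membership survives the room being purged
# from current_state_events when the server leaves.
def get_local_membership(txn, user_id, room_id):
    txn.execute(
        """
        SELECT event_id, membership FROM local_current_membership
        WHERE user_id = ? AND room_id = ?
        """,
        (user_id, room_id),
    )
    return txn.fetchone()  # e.g. ("$event_id", "leave") or None
```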
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index b7c4eda338..c84cb452b0 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -32,20 +32,7 @@ class PostgresEngine(object):
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
- def check_database(self, txn):
- txn.execute("SHOW SERVER_ENCODING")
- rows = txn.fetchall()
- if rows and rows[0][0] != "UTF8":
- raise IncorrectDatabaseSetup(
- "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
- "See docs/postgres.rst for more information." % (rows[0][0],)
- )
-
- def convert_param_style(self, sql):
- return sql.replace("?", "%s")
-
- def on_new_connection(self, db_conn):
-
+ def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
@@ -53,9 +40,22 @@ class PostgresEngine(object):
self._version = db_conn.server_version
# Are we on a supported PostgreSQL version?
- if self._version < 90500:
+ if not allow_outdated_version and self._version < 90500:
raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
+ with db_conn.cursor() as txn:
+ txn.execute("SHOW SERVER_ENCODING")
+ rows = txn.fetchall()
+ if rows and rows[0][0] != "UTF8":
+ raise IncorrectDatabaseSetup(
+ "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
+ "See docs/postgres.rst for more information." % (rows[0][0],)
+ )
+
+ def convert_param_style(self, sql):
+ return sql.replace("?", "%s")
+
+ def on_new_connection(self, db_conn):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
@@ -119,8 +119,8 @@ class PostgresEngine(object):
Returns:
string
"""
- # note that this is a bit of a hack because it relies on on_new_connection
- # having been called at least once. Still, that should be a safe bet here.
+ # note that this is a bit of a hack because it relies on check_database
+ # having been called. Still, that should be a safe bet here.
numver = self._version
assert numver is not None
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index df039a072d..cbf52f5191 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -53,8 +53,11 @@ class Sqlite3Engine(object):
"""
return False
- def check_database(self, txn):
- pass
+ def check_database(self, db_conn, allow_outdated_version: bool = False):
+ if not allow_outdated_version:
+ version = self.module.sqlite_version_info
+ if version < (3, 11, 0):
+ raise RuntimeError("Synapse requires sqlite 3.11 or above.")
def convert_param_style(self, sql):
return sql
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e70026b80a..e86984cd50 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 56
+SCHEMA_VERSION = 57
dir_path = os.path.abspath(os.path.dirname(__file__))
|