diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/config/server.py | 38 | ||||
-rw-r--r-- | synapse/events/utils.py | 44 | ||||
-rw-r--r-- | synapse/handlers/events.py | 8 | ||||
-rw-r--r-- | synapse/handlers/initial_sync.py | 44 | ||||
-rw-r--r-- | synapse/handlers/message.py | 7 | ||||
-rw-r--r-- | synapse/handlers/pagination.py | 22 | ||||
-rw-r--r-- | synapse/handlers/search.py | 42 | ||||
-rw-r--r-- | synapse/http/client.py | 6 | ||||
-rw-r--r-- | synapse/http/matrixfederationclient.py | 48 | ||||
-rw-r--r-- | synapse/rest/client/v1/events.py | 5 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 29 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/notifications.py | 10 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/sync.py | 59 | ||||
-rw-r--r-- | synapse/rest/media/v1/storage_provider.py | 1 | ||||
-rw-r--r-- | synapse/server.py | 5 | ||||
-rw-r--r-- | synapse/storage/appservice.py | 4 | ||||
-rw-r--r-- | synapse/storage/event_federation.py | 6 | ||||
-rw-r--r-- | synapse/storage/events_worker.py | 50 | ||||
-rw-r--r-- | synapse/storage/search.py | 4 | ||||
-rw-r--r-- | synapse/storage/stream.py | 18 | ||||
-rw-r--r-- | synapse/util/async_helpers.py | 19 | ||||
-rw-r--r-- | synapse/util/distributor.py | 1 |
22 files changed, 326 insertions, 144 deletions
diff --git a/synapse/config/server.py b/synapse/config/server.py index 8dce75c56a..7874cd9da7 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -17,6 +17,8 @@ import logging import os.path +from netaddr import IPSet + from synapse.http.endpoint import parse_and_validate_server_name from synapse.python_dependencies import DependencyException, check_requirements @@ -137,6 +139,24 @@ class ServerConfig(Config): for domain in federation_domain_whitelist: self.federation_domain_whitelist[domain] = True + self.federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", [], + ) + + # Attempt to create an IPSet from the given ranges + try: + self.federation_ip_range_blacklist = IPSet( + self.federation_ip_range_blacklist + ) + + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + except Exception as e: + raise ConfigError( + "Invalid range(s) provided in " + "federation_ip_range_blacklist: %s" % e + ) + if self.public_baseurl is not None: if self.public_baseurl[-1] != '/': self.public_baseurl += '/' @@ -386,6 +406,24 @@ class ServerConfig(Config): # - nyc.example.com # - syd.example.com + # Prevent federation requests from being sent to the following + # blacklist IP address CIDR ranges. If this option is not specified, or + # specified with an empty list, no ip range blacklist will be enforced. + # + # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly + # listed here, since they correspond to unroutable addresses.) + # + federation_ip_range_blacklist: + - '127.0.0.0/8' + - '10.0.0.0/8' + - '172.16.0.0/12' + - '192.168.0.0/16' + - '100.64.0.0/10' + - '169.254.0.0/16' + - '::1/128' + - 'fe80::/64' + - 'fc00::/7' + # List of ports that Synapse should listen on, their purpose and their # configuration. # diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 07fccdd8f9..a5454556cc 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -19,7 +19,10 @@ from six import string_types from frozendict import frozendict +from twisted.internet import defer + from synapse.api.constants import EventTypes +from synapse.util.async_helpers import yieldable_gather_results from . import EventBase @@ -311,3 +314,44 @@ def serialize_event(e, time_now_ms, as_client_event=True, d = only_fields(d, only_event_fields) return d + + +class EventClientSerializer(object): + """Serializes events that are to be sent to clients. + + This is used for bundling extra information with any events to be sent to + clients. + """ + + def __init__(self, hs): + pass + + def serialize_event(self, event, time_now, **kwargs): + """Serializes a single event. + + Args: + event (EventBase) + time_now (int): The current time in milliseconds + **kwargs: Arguments to pass to `serialize_event` + + Returns: + Deferred[dict]: The serialized event + """ + event = serialize_event(event, time_now, **kwargs) + return defer.succeed(event) + + def serialize_events(self, events, time_now, **kwargs): + """Serializes multiple events. + + Args: + event (iter[EventBase]) + time_now (int): The current time in milliseconds + **kwargs: Arguments to pass to `serialize_event` + + Returns: + Deferred[list[dict]]: The list of serialized events + """ + return yieldable_gather_results( + self.serialize_event, events, + time_now=time_now, **kwargs + ) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 1b4d8c74ae..6003ad9cca 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -21,7 +21,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, SynapseError from synapse.events import EventBase -from synapse.events.utils import serialize_event from synapse.types import UserID from synapse.util.logutils import log_function from synapse.visibility import filter_events_for_client @@ -50,6 +49,7 @@ class EventStreamHandler(BaseHandler): self.notifier = hs.get_notifier() self.state = hs.get_state_handler() self._server_notices_sender = hs.get_server_notices_sender() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks @log_function @@ -120,9 +120,9 @@ class EventStreamHandler(BaseHandler): time_now = self.clock.time_msec() - chunks = [ - serialize_event(e, time_now, as_client_event) for e in events - ] + chunks = yield self._event_serializer.serialize_events( + events, time_now, as_client_event=as_client_event, + ) chunk = { "chunk": chunks, diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 7dfae78db0..aaee5db0b7 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -19,7 +19,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.handlers.presence import format_user_presence_state from synapse.streams.config import PaginationConfig @@ -43,6 +42,7 @@ class InitialSyncHandler(BaseHandler): self.clock = hs.get_clock() self.validator = EventValidator() self.snapshot_cache = SnapshotCache() + self._event_serializer = hs.get_event_client_serializer() def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True, include_archived=False): @@ -138,7 +138,9 @@ class InitialSyncHandler(BaseHandler): d["inviter"] = event.sender invite_event = yield self.store.get_event(event.event_id) - d["invite"] = serialize_event(invite_event, time_now, as_client_event) + d["invite"] = yield self._event_serializer.serialize_event( + invite_event, time_now, as_client_event, + ) rooms_ret.append(d) @@ -185,18 +187,21 @@ class InitialSyncHandler(BaseHandler): time_now = self.clock.time_msec() d["messages"] = { - "chunk": [ - serialize_event(m, time_now, as_client_event) - for m in messages - ], + "chunk": ( + yield self._event_serializer.serialize_events( + messages, time_now=time_now, + as_client_event=as_client_event, + ) + ), "start": start_token.to_string(), "end": end_token.to_string(), } - d["state"] = [ - serialize_event(c, time_now, as_client_event) - for c in current_state.values() - ] + d["state"] = yield self._event_serializer.serialize_events( + current_state.values(), + time_now=time_now, + as_client_event=as_client_event + ) account_data_events = [] tags = tags_by_room.get(event.room_id) @@ -337,11 +342,15 @@ class InitialSyncHandler(BaseHandler): "membership": membership, "room_id": room_id, "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], + "chunk": (yield self._event_serializer.serialize_events( + messages, time_now, + )), "start": start_token.to_string(), "end": end_token.to_string(), }, - "state": [serialize_event(s, time_now) for s in room_state.values()], + "state": (yield self._event_serializer.serialize_events( + room_state.values(), time_now, + )), "presence": [], "receipts": [], }) @@ -355,10 +364,9 @@ class InitialSyncHandler(BaseHandler): # TODO: These concurrently time_now = self.clock.time_msec() - state = [ - serialize_event(x, time_now) - for x in current_state.values() - ] + state = yield self._event_serializer.serialize_events( + current_state.values(), time_now, + ) now_token = yield self.hs.get_event_sources().get_current_token() @@ -425,7 +433,9 @@ class InitialSyncHandler(BaseHandler): ret = { "room_id": room_id, "messages": { - "chunk": [serialize_event(m, time_now) for m in messages], + "chunk": (yield self._event_serializer.serialize_events( + messages, time_now, + )), "start": start_token.to_string(), "end": end_token.to_string(), }, diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e5afeadf68..7b2c33a922 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -32,7 +32,6 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import RoomVersions from synapse.api.urls import ConsentURIBuilder -from synapse.events.utils import serialize_event from synapse.events.validator import EventValidator from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.state import StateFilter @@ -57,6 +56,7 @@ class MessageHandler(object): self.clock = hs.get_clock() self.state = hs.get_state_handler() self.store = hs.get_datastore() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def get_room_data(self, user_id=None, room_id=None, @@ -164,9 +164,10 @@ class MessageHandler(object): room_state = room_state[membership_event_id] now = self.clock.time_msec() - defer.returnValue( - [serialize_event(c, now) for c in room_state.values()] + events = yield self._event_serializer.serialize_events( + room_state.values(), now, ) + defer.returnValue(events) @defer.inlineCallbacks def get_joined_members(self, requester, room_id): diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index e4fdae9266..8f811e24fe 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -20,7 +20,6 @@ from twisted.python.failure import Failure from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError -from synapse.events.utils import serialize_event from synapse.storage.state import StateFilter from synapse.types import RoomStreamToken from synapse.util.async_helpers import ReadWriteLock @@ -78,6 +77,7 @@ class PaginationHandler(object): self._purges_in_progress_by_room = set() # map from purge id to PurgeStatus self._purges_by_id = {} + self._event_serializer = hs.get_event_client_serializer() def start_purge_history(self, room_id, token, delete_local_events=False): @@ -278,18 +278,22 @@ class PaginationHandler(object): time_now = self.clock.time_msec() chunk = { - "chunk": [ - serialize_event(e, time_now, as_client_event) - for e in events - ], + "chunk": ( + yield self._event_serializer.serialize_events( + events, time_now, + as_client_event=as_client_event, + ) + ), "start": pagin_config.from_token.to_string(), "end": next_token.to_string(), } if state: - chunk["state"] = [ - serialize_event(e, time_now, as_client_event) - for e in state - ] + chunk["state"] = ( + yield self._event_serializer.serialize_events( + state, time_now, + as_client_event=as_client_event, + ) + ) defer.returnValue(chunk) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 49c439313e..9bba74d6c9 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -23,7 +23,6 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import SynapseError from synapse.api.filtering import Filter -from synapse.events.utils import serialize_event from synapse.storage.state import StateFilter from synapse.visibility import filter_events_for_client @@ -36,6 +35,7 @@ class SearchHandler(BaseHandler): def __init__(self, hs): super(SearchHandler, self).__init__(hs) + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def get_old_rooms_from_upgraded_room(self, room_id): @@ -401,14 +401,16 @@ class SearchHandler(BaseHandler): time_now = self.clock.time_msec() for context in contexts.values(): - context["events_before"] = [ - serialize_event(e, time_now) - for e in context["events_before"] - ] - context["events_after"] = [ - serialize_event(e, time_now) - for e in context["events_after"] - ] + context["events_before"] = ( + yield self._event_serializer.serialize_events( + context["events_before"], time_now, + ) + ) + context["events_after"] = ( + yield self._event_serializer.serialize_events( + context["events_after"], time_now, + ) + ) state_results = {} if include_state: @@ -422,14 +424,13 @@ class SearchHandler(BaseHandler): # We're now about to serialize the events. We should not make any # blocking calls after this. Otherwise the 'age' will be wrong - results = [ - { + results = [] + for e in allowed_events: + results.append({ "rank": rank_map[e.event_id], - "result": serialize_event(e, time_now), + "result": (yield self._event_serializer.serialize_event(e, time_now)), "context": contexts.get(e.event_id, {}), - } - for e in allowed_events - ] + }) rooms_cat_res = { "results": results, @@ -438,10 +439,13 @@ class SearchHandler(BaseHandler): } if state_results: - rooms_cat_res["state"] = { - room_id: [serialize_event(e, time_now) for e in state] - for room_id, state in state_results.items() - } + s = {} + for room_id, state in state_results.items(): + s[room_id] = yield self._event_serializer.serialize_events( + state, time_now, + ) + + rooms_cat_res["state"] = s if room_groups and "room_id" in group_keys: rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups diff --git a/synapse/http/client.py b/synapse/http/client.py index ddbfb72228..77fe68818b 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -165,7 +165,8 @@ class BlacklistingAgentWrapper(Agent): ip_address, self._ip_whitelist, self._ip_blacklist ): logger.info( - "Blocking access to %s because of blacklist" % (ip_address,) + "Blocking access to %s due to blacklist" % + (ip_address,) ) e = SynapseError(403, "IP address blocked by IP blacklist entry") return defer.fail(Failure(e)) @@ -263,9 +264,6 @@ class SimpleHttpClient(object): uri (str): URI to query. data (bytes): Data to send in the request body, if applicable. headers (t.w.http_headers.Headers): Request headers. - - Raises: - SynapseError: If the IP is blacklisted. """ # A small wrapper around self.agent.request() so we can easily attach # counters to it diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index ff63d0b2a8..7eefc7b1fc 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -27,9 +27,11 @@ import treq from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json +from zope.interface import implementer from twisted.internet import defer, protocol from twisted.internet.error import DNSLookupError +from twisted.internet.interfaces import IReactorPluggableNameResolver from twisted.internet.task import _EPSILON, Cooperator from twisted.web._newclient import ResponseDone from twisted.web.http_headers import Headers @@ -44,6 +46,7 @@ from synapse.api.errors import ( SynapseError, ) from synapse.http import QuieterFileBodyProducer +from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.util.async_helpers import timeout_deferred from synapse.util.logcontext import make_deferred_yieldable @@ -172,19 +175,44 @@ class MatrixFederationHttpClient(object): self.hs = hs self.signing_key = hs.config.signing_key[0] self.server_name = hs.hostname - reactor = hs.get_reactor() + + real_reactor = hs.get_reactor() + + # We need to use a DNS resolver which filters out blacklisted IP + # addresses, to prevent DNS rebinding. + nameResolver = IPBlacklistingResolver( + real_reactor, None, hs.config.federation_ip_range_blacklist, + ) + + @implementer(IReactorPluggableNameResolver) + class Reactor(object): + def __getattr__(_self, attr): + if attr == "nameResolver": + return nameResolver + else: + return getattr(real_reactor, attr) + + self.reactor = Reactor() self.agent = MatrixFederationAgent( - hs.get_reactor(), + self.reactor, tls_client_options_factory, ) + + # Use a BlacklistingAgentWrapper to prevent circumventing the IP + # blacklist via IP literals in server names + self.agent = BlacklistingAgentWrapper( + self.agent, self.reactor, + ip_blacklist=hs.config.federation_ip_range_blacklist, + ) + self.clock = hs.get_clock() self._store = hs.get_datastore() self.version_string_bytes = hs.version_string.encode('ascii') self.default_timeout = 60 def schedule(x): - reactor.callLater(_EPSILON, x) + self.reactor.callLater(_EPSILON, x) self._cooperator = Cooperator(scheduler=schedule) @@ -370,7 +398,7 @@ class MatrixFederationHttpClient(object): request_deferred = timeout_deferred( request_deferred, timeout=_sec_timeout, - reactor=self.hs.get_reactor(), + reactor=self.reactor, ) response = yield request_deferred @@ -397,7 +425,7 @@ class MatrixFederationHttpClient(object): d = timeout_deferred( d, timeout=_sec_timeout, - reactor=self.hs.get_reactor(), + reactor=self.reactor, ) try: @@ -586,7 +614,7 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.hs.get_reactor(), self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response, ) defer.returnValue(body) @@ -645,7 +673,7 @@ class MatrixFederationHttpClient(object): _sec_timeout = self.default_timeout body = yield _handle_json_response( - self.hs.get_reactor(), _sec_timeout, request, response, + self.reactor, _sec_timeout, request, response, ) defer.returnValue(body) @@ -704,7 +732,7 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.hs.get_reactor(), self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response, ) defer.returnValue(body) @@ -753,7 +781,7 @@ class MatrixFederationHttpClient(object): ) body = yield _handle_json_response( - self.hs.get_reactor(), self.default_timeout, request, response, + self.reactor, self.default_timeout, request, response, ) defer.returnValue(body) @@ -801,7 +829,7 @@ class MatrixFederationHttpClient(object): try: d = _readBodyToFile(response, output_stream, max_size) - d.addTimeout(self.default_timeout, self.hs.get_reactor()) + d.addTimeout(self.default_timeout, self.reactor) length = yield make_deferred_yieldable(d) except Exception as e: logger.warn( diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index cd9b3bdbd1..c3b0a39ab7 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -19,7 +19,6 @@ import logging from twisted.internet import defer from synapse.api.errors import SynapseError -from synapse.events.utils import serialize_event from synapse.streams.config import PaginationConfig from .base import ClientV1RestServlet, client_path_patterns @@ -84,6 +83,7 @@ class EventRestServlet(ClientV1RestServlet): super(EventRestServlet, self).__init__(hs) self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def on_GET(self, request, event_id): @@ -92,7 +92,8 @@ class EventRestServlet(ClientV1RestServlet): time_now = self.clock.time_msec() if event: - defer.returnValue((200, serialize_event(event, time_now))) + event = yield self._event_serializer.serialize_event(event, time_now) + defer.returnValue((200, event)) else: defer.returnValue((404, "Event not found.")) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index fab04965cb..255a85c588 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -26,7 +26,7 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, Membership from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.filtering import Filter -from synapse.events.utils import format_event_for_client_v2, serialize_event +from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( assert_params_in_dict, parse_integer, @@ -537,6 +537,7 @@ class RoomEventServlet(ClientV1RestServlet): super(RoomEventServlet, self).__init__(hs) self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -545,7 +546,8 @@ class RoomEventServlet(ClientV1RestServlet): time_now = self.clock.time_msec() if event: - defer.returnValue((200, serialize_event(event, time_now))) + event = yield self._event_serializer.serialize_event(event, time_now) + defer.returnValue((200, event)) else: defer.returnValue((404, "Event not found.")) @@ -559,6 +561,7 @@ class RoomEventContextServlet(ClientV1RestServlet): super(RoomEventContextServlet, self).__init__(hs) self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -588,16 +591,18 @@ class RoomEventContextServlet(ClientV1RestServlet): ) time_now = self.clock.time_msec() - results["events_before"] = [ - serialize_event(event, time_now) for event in results["events_before"] - ] - results["event"] = serialize_event(results["event"], time_now) - results["events_after"] = [ - serialize_event(event, time_now) for event in results["events_after"] - ] - results["state"] = [ - serialize_event(event, time_now) for event in results["state"] - ] + results["events_before"] = yield self._event_serializer.serialize_events( + results["events_before"], time_now, + ) + results["event"] = yield self._event_serializer.serialize_event( + results["event"], time_now, + ) + results["events_after"] = yield self._event_serializer.serialize_events( + results["events_after"], time_now, + ) + results["state"] = yield self._event_serializer.serialize_events( + results["state"], time_now, + ) defer.returnValue((200, results)) diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 2a6ea3df5f..0a1eb0ae45 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -17,10 +17,7 @@ import logging from twisted.internet import defer -from synapse.events.utils import ( - format_event_for_client_v2_without_room_id, - serialize_event, -) +from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.servlet import RestServlet, parse_integer, parse_string from ._base import client_v2_patterns @@ -36,6 +33,7 @@ class NotificationsServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def on_GET(self, request): @@ -69,11 +67,11 @@ class NotificationsServlet(RestServlet): "profile_tag": pa["profile_tag"], "actions": pa["actions"], "ts": pa["received_ts"], - "event": serialize_event( + "event": (yield self._event_serializer.serialize_event( notif_events[pa["event_id"]], self.clock.time_msec(), event_format=format_event_for_client_v2_without_room_id, - ), + )), } if pa["room_id"] not in receipts_by_room: diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index 39d157a44b..c701e534e7 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -26,7 +26,6 @@ from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection from synapse.events.utils import ( format_event_for_client_v2_without_room_id, format_event_raw, - serialize_event, ) from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig @@ -86,6 +85,7 @@ class SyncRestServlet(RestServlet): self.filtering = hs.get_filtering() self.presence_handler = hs.get_presence_handler() self._server_notices_sender = hs.get_server_notices_sender() + self._event_serializer = hs.get_event_client_serializer() @defer.inlineCallbacks def on_GET(self, request): @@ -168,14 +168,14 @@ class SyncRestServlet(RestServlet): ) time_now = self.clock.time_msec() - response_content = self.encode_response( + response_content = yield self.encode_response( time_now, sync_result, requester.access_token_id, filter ) defer.returnValue((200, response_content)) - @staticmethod - def encode_response(time_now, sync_result, access_token_id, filter): + @defer.inlineCallbacks + def encode_response(self, time_now, sync_result, access_token_id, filter): if filter.event_format == 'client': event_formatter = format_event_for_client_v2_without_room_id elif filter.event_format == 'federation': @@ -183,24 +183,24 @@ class SyncRestServlet(RestServlet): else: raise Exception("Unknown event format %s" % (filter.event_format, )) - joined = SyncRestServlet.encode_joined( + joined = yield self.encode_joined( sync_result.joined, time_now, access_token_id, filter.event_fields, event_formatter, ) - invited = SyncRestServlet.encode_invited( + invited = yield self.encode_invited( sync_result.invited, time_now, access_token_id, event_formatter, ) - archived = SyncRestServlet.encode_archived( + archived = yield self.encode_archived( sync_result.archived, time_now, access_token_id, filter.event_fields, event_formatter, ) - return { + defer.returnValue({ "account_data": {"events": sync_result.account_data}, "to_device": {"events": sync_result.to_device}, "device_lists": { @@ -222,7 +222,7 @@ class SyncRestServlet(RestServlet): }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, "next_batch": sync_result.next_batch.to_string(), - } + }) @staticmethod def encode_presence(events, time_now): @@ -239,8 +239,8 @@ class SyncRestServlet(RestServlet): ] } - @staticmethod - def encode_joined(rooms, time_now, token_id, event_fields, event_formatter): + @defer.inlineCallbacks + def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter): """ Encode the joined rooms in a sync result @@ -261,15 +261,15 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = SyncRestServlet.encode_room( + joined[room.room_id] = yield self.encode_room( room, time_now, token_id, joined=True, only_fields=event_fields, event_formatter=event_formatter, ) - return joined + defer.returnValue(joined) - @staticmethod - def encode_invited(rooms, time_now, token_id, event_formatter): + @defer.inlineCallbacks + def encode_invited(self, rooms, time_now, token_id, event_formatter): """ Encode the invited rooms in a sync result @@ -289,7 +289,7 @@ class SyncRestServlet(RestServlet): """ invited = {} for room in rooms: - invite = serialize_event( + invite = yield self._event_serializer.serialize_event( room.invite, time_now, token_id=token_id, event_format=event_formatter, is_invite=True, @@ -302,10 +302,10 @@ class SyncRestServlet(RestServlet): "invite_state": {"events": invited_state} } - return invited + defer.returnValue(invited) - @staticmethod - def encode_archived(rooms, time_now, token_id, event_fields, event_formatter): + @defer.inlineCallbacks + def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter): """ Encode the archived rooms in a sync result @@ -326,17 +326,17 @@ class SyncRestServlet(RestServlet): """ joined = {} for room in rooms: - joined[room.room_id] = SyncRestServlet.encode_room( + joined[room.room_id] = yield self.encode_room( room, time_now, token_id, joined=False, only_fields=event_fields, event_formatter=event_formatter, ) - return joined + defer.returnValue(joined) - @staticmethod + @defer.inlineCallbacks def encode_room( - room, time_now, token_id, joined, + self, room, time_now, token_id, joined, only_fields, event_formatter, ): """ @@ -355,9 +355,10 @@ class SyncRestServlet(RestServlet): Returns: dict[str, object]: the room, encoded in our response format """ - def serialize(event): - return serialize_event( - event, time_now, token_id=token_id, + def serialize(events): + return self._event_serializer.serialize_events( + events, time_now=time_now, + token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, ) @@ -376,8 +377,8 @@ class SyncRestServlet(RestServlet): event.event_id, room.room_id, event.room_id, ) - serialized_state = [serialize(e) for e in state_events] - serialized_timeline = [serialize(e) for e in timeline_events] + serialized_state = yield serialize(state_events) + serialized_timeline = yield serialize(timeline_events) account_data = room.account_data @@ -397,7 +398,7 @@ class SyncRestServlet(RestServlet): result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary - return result + defer.returnValue(result) def register_servlets(hs, http_server): diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 5aa03031f6..d90cbfb56a 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -108,6 +108,7 @@ class FileStorageProviderBackend(StorageProvider): """ def __init__(self, hs, config): + self.hs = hs self.cache_directory = hs.config.media_store_path self.base_directory = config diff --git a/synapse/server.py b/synapse/server.py index 8c30ac2fa5..80d40b9272 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -35,6 +35,7 @@ from synapse.crypto import context_factory from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory from synapse.events.spamcheck import SpamChecker +from synapse.events.utils import EventClientSerializer from synapse.federation.federation_client import FederationClient from synapse.federation.federation_server import ( FederationHandlerRegistry, @@ -185,6 +186,7 @@ class HomeServer(object): 'sendmail', 'registration_handler', 'account_validity_handler', + 'event_client_serializer', ] REQUIRED_ON_MASTER_STARTUP = [ @@ -511,6 +513,9 @@ class HomeServer(object): def build_account_validity_handler(self): return AccountValidityHandler(self) + def build_event_client_serializer(self): + return EventClientSerializer(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 6092f600ba..eb329ebd8b 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -302,7 +302,7 @@ class ApplicationServiceTransactionWorkerStore( event_ids = json.loads(entry["event_ids"]) - events = yield self._get_events(event_ids) + events = yield self.get_events_as_list(event_ids) defer.returnValue( AppServiceTransaction(service=service, id=entry["txn_id"], events=events) @@ -358,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore( "get_new_events_for_appservice", get_new_events_for_appservice_txn ) - events = yield self._get_events(event_ids) + events = yield self.get_events_as_list(event_ids) defer.returnValue((upper_bound, events)) diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py index 956f876572..09e39c2c28 100644 --- a/synapse/storage/event_federation.py +++ b/synapse/storage/event_federation.py @@ -45,7 +45,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas """ return self.get_auth_chain_ids( event_ids, include_given=include_given - ).addCallback(self._get_events) + ).addCallback(self.get_events_as_list) def get_auth_chain_ids(self, event_ids, include_given=False): """Get auth events for given event_ids. The events *must* be state events. @@ -316,7 +316,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas event_list, limit, ) - .addCallback(self._get_events) + .addCallback(self.get_events_as_list) .addCallback(lambda l: sorted(l, key=lambda e: -e.depth)) ) @@ -382,7 +382,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas latest_events, limit, ) - events = yield self._get_events(ids) + events = yield self.get_events_as_list(ids) defer.returnValue(events) def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit): diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py index 663991a9b6..adc6cf26b5 100644 --- a/synapse/storage/events_worker.py +++ b/synapse/storage/events_worker.py @@ -103,7 +103,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred : A FrozenEvent. """ - events = yield self._get_events( + events = yield self.get_events_as_list( [event_id], check_redacted=check_redacted, get_prev_content=get_prev_content, @@ -142,7 +142,7 @@ class EventsWorkerStore(SQLBaseStore): Returns: Deferred : Dict from event_id to event. """ - events = yield self._get_events( + events = yield self.get_events_as_list( event_ids, check_redacted=check_redacted, get_prev_content=get_prev_content, @@ -152,13 +152,32 @@ class EventsWorkerStore(SQLBaseStore): defer.returnValue({e.event_id: e for e in events}) @defer.inlineCallbacks - def _get_events( + def get_events_as_list( self, event_ids, check_redacted=True, get_prev_content=False, allow_rejected=False, ): + """Get events from the database and return in a list in the same order + as given by `event_ids` arg. + + Args: + event_ids (list): The event_ids of the events to fetch + check_redacted (bool): If True, check if event has been redacted + and redact it. + get_prev_content (bool): If True and event is a state event, + include the previous states content in the unsigned field. + allow_rejected (bool): If True return rejected events. + + Returns: + Deferred[list[EventBase]]: List of events fetched from the database. The + events are in the same order as `event_ids` arg. + + Note that the returned list may be smaller than the list of event + IDs if not all events could be fetched. + """ + if not event_ids: defer.returnValue([]) @@ -202,21 +221,22 @@ class EventsWorkerStore(SQLBaseStore): # # The problem is that we end up at this point when an event # which has been redacted is pulled out of the database by - # _enqueue_events, because _enqueue_events needs to check the - # redaction before it can cache the redacted event. So obviously, - # calling get_event to get the redacted event out of the database - # gives us an infinite loop. + # _enqueue_events, because _enqueue_events needs to check + # the redaction before it can cache the redacted event. So + # obviously, calling get_event to get the redacted event out + # of the database gives us an infinite loop. # - # For now (quick hack to fix during 0.99 release cycle), we just - # go and fetch the relevant row from the db, but it would be nice - # to think about how we can cache this rather than hit the db - # every time we access a redaction event. + # For now (quick hack to fix during 0.99 release cycle), we + # just go and fetch the relevant row from the db, but it + # would be nice to think about how we can cache this rather + # than hit the db every time we access a redaction event. # # One thought on how to do this: - # 1. split _get_events up so that it is divided into (a) get the - # rawish event from the db/cache, (b) do the redaction/rejection - # filtering - # 2. have _get_event_from_row just call the first half of that + # 1. split get_events_as_list up so that it is divided into + # (a) get the rawish event from the db/cache, (b) do the + # redaction/rejection filtering + # 2. have _get_event_from_row just call the first half of + # that orig_sender = yield self._simple_select_one_onecol( table="events", diff --git a/synapse/storage/search.py b/synapse/storage/search.py index 226f8f1b7e..ff49eaae02 100644 --- a/synapse/storage/search.py +++ b/synapse/storage/search.py @@ -460,7 +460,7 @@ class SearchStore(BackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self._get_events([r["event_id"] for r in results]) + events = yield self.get_events_as_list([r["event_id"] for r in results]) event_map = {ev.event_id: ev for ev in events} @@ -605,7 +605,7 @@ class SearchStore(BackgroundUpdateStore): results = list(filter(lambda row: row["room_id"] in room_ids, results)) - events = yield self._get_events([r["event_id"] for r in results]) + events = yield self.get_events_as_list([r["event_id"] for r in results]) event_map = {ev.event_id: ev for ev in events} diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 9cd1e0f9fe..d105b6b17d 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -319,7 +319,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = yield self.runInteraction("get_room_events_stream_for_room", f) - ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True) + ret = yield self.get_events_as_list([ + r.event_id for r in rows], get_prev_content=True, + ) self._set_before_and_after(ret, rows, topo_order=from_id is None) @@ -367,7 +369,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): rows = yield self.runInteraction("get_membership_changes_for_user", f) - ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True) + ret = yield self.get_events_as_list( + [r.event_id for r in rows], get_prev_content=True, + ) self._set_before_and_after(ret, rows, topo_order=False) @@ -394,7 +398,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): ) logger.debug("stream before") - events = yield self._get_events( + events = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) logger.debug("stream after") @@ -580,11 +584,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events_before = yield self._get_events( + events_before = yield self.get_events_as_list( [e for e in results["before"]["event_ids"]], get_prev_content=True ) - events_after = yield self._get_events( + events_after = yield self.get_events_as_list( [e for e in results["after"]["event_ids"]], get_prev_content=True ) @@ -697,7 +701,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): "get_all_new_events_stream", get_all_new_events_stream_txn ) - events = yield self._get_events(event_ids) + events = yield self.get_events_as_list(event_ids) defer.returnValue((upper_bound, events)) @@ -849,7 +853,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): event_filter, ) - events = yield self._get_events( + events = yield self.get_events_as_list( [r.event_id for r in rows], get_prev_content=True ) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 2f16f23d91..7253ba120f 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -156,6 +156,25 @@ def concurrently_execute(func, args, limit): ], consumeErrors=True)).addErrback(unwrapFirstError) +def yieldable_gather_results(func, iter, *args, **kwargs): + """Executes the function with each argument concurrently. + + Args: + func (func): Function to execute that returns a Deferred + iter (iter): An iterable that yields items that get passed as the first + argument to the function + *args: Arguments to be passed to each call to func + + Returns + Deferred[list]: Resolved when all functions have been invoked, or errors if + one of the function calls fails. + """ + return logcontext.make_deferred_yieldable(defer.gatherResults([ + run_in_background(func, item, *args, **kwargs) + for item in iter + ], consumeErrors=True)).addErrback(unwrapFirstError) + + class Linearizer(object): """Limits concurrent access to resources based on a key. Useful to ensure only a few things happen at a time on a given resource. diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index 194da87639..e14c8bdfda 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -27,6 +27,7 @@ def user_left_room(distributor, user, room_id): distributor.fire("user_left_room", user=user, room_id=room_id) +# XXX: this is no longer used. We should probably kill it. def user_joined_room(distributor, user, room_id): distributor.fire("user_joined_room", user=user, room_id=room_id) |