diff --git a/synapse/__init__.py b/synapse/__init__.py
index cd9cfb2409..bf9e810da6 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -27,4 +27,4 @@ try:
except ImportError:
pass
-__version__ = "0.99.4rc1"
+__version__ = "0.99.4"
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 8547a63535..6b347b1749 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -23,6 +23,9 @@ MAX_DEPTH = 2**63 - 1
# the maximum length for a room alias is 255 characters
MAX_ALIAS_LENGTH = 255
+# the maximum length for a user id is 255 characters
+MAX_USERID_LENGTH = 255
+
class Membership(object):
@@ -116,3 +119,11 @@ class UserTypes(object):
"""
SUPPORT = "support"
ALL_USER_TYPES = (SUPPORT,)
+
+
+class RelationTypes(object):
+ """The types of relations known to this server.
+ """
+ ANNOTATION = "m.annotation"
+ REPLACE = "m.replace"
+ REFERENCE = "m.reference"
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index e77abe1040..485b3d0237 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -19,13 +19,15 @@ class EventFormatVersions(object):
"""This is an internal enum for tracking the version of the event format,
independently from the room version.
"""
- V1 = 1 # $id:server format
- V2 = 2 # MSC1659-style $hash format: introduced for room v3
+ V1 = 1 # $id:server event id format
+ V2 = 2 # MSC1659-style $hash event id format: introduced for room v3
+ V3 = 3 # MSC1884-style $hash format: introduced for room v4
KNOWN_EVENT_FORMAT_VERSIONS = {
EventFormatVersions.V1,
EventFormatVersions.V2,
+ EventFormatVersions.V3,
}
@@ -75,6 +77,12 @@ class RoomVersions(object):
EventFormatVersions.V2,
StateResolutionVersions.V2,
)
+ EVENTID_NOSLASH_TEST = RoomVersion(
+ "eventid-noslash-test",
+ RoomDisposition.UNSTABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ )
# the version we will give rooms which are created on this server
@@ -87,5 +95,6 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V2,
RoomVersions.V3,
RoomVersions.STATE_V2_TEST,
+ RoomVersions.EVENTID_NOSLASH_TEST,
)
} # type: dict[str, RoomVersion]
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index cb71d80875..3c6bddff7a 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -22,8 +22,7 @@ from six.moves.urllib.parse import urlencode
from synapse.config import ConfigError
-CLIENT_PREFIX = "/_matrix/client/api/v1"
-CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
+CLIENT_API_PREFIX = "/_matrix/client"
FEDERATION_PREFIX = "/_matrix/federation"
FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 864f1eac48..707f7c2f07 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -29,6 +29,7 @@ from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.replication.slave.storage import SlavedProfileStore
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -83,6 +84,7 @@ class ClientReaderSlavedStore(
SlavedTransactionStore,
SlavedClientIpStore,
BaseSlavedStore,
+ SlavedProfileStore,
):
pass
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 5a68399e63..5a9adac480 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -16,16 +16,56 @@ from ._base import Config
class RateLimitConfig(object):
- def __init__(self, config):
- self.per_second = config.get("per_second", 0.17)
- self.burst_count = config.get("burst_count", 3.0)
+ def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}):
+ self.per_second = config.get("per_second", defaults["per_second"])
+ self.burst_count = config.get("burst_count", defaults["burst_count"])
-class RatelimitConfig(Config):
+class FederationRateLimitConfig(object):
+ _items_and_default = {
+ "window_size": 10000,
+ "sleep_limit": 10,
+ "sleep_delay": 500,
+ "reject_limit": 50,
+ "concurrent": 3,
+ }
+
+ def __init__(self, **kwargs):
+ for i in self._items_and_default.keys():
+ setattr(self, i, kwargs.get(i) or self._items_and_default[i])
+
+class RatelimitConfig(Config):
def read_config(self, config):
- self.rc_messages_per_second = config.get("rc_messages_per_second", 0.2)
- self.rc_message_burst_count = config.get("rc_message_burst_count", 10.0)
+
+ # Load the new-style messages config if it exists. Otherwise fall back
+ # to the old method.
+ if "rc_message" in config:
+ self.rc_message = RateLimitConfig(
+ config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0}
+ )
+ else:
+ self.rc_message = RateLimitConfig(
+ {
+ "per_second": config.get("rc_messages_per_second", 0.2),
+ "burst_count": config.get("rc_message_burst_count", 10.0),
+ }
+ )
+
+ # Load the new-style federation config, if it exists. Otherwise, fall
+ # back to the old method.
+ if "federation_rc" in config:
+ self.rc_federation = FederationRateLimitConfig(**config["rc_federation"])
+ else:
+ self.rc_federation = FederationRateLimitConfig(
+ **{
+ "window_size": config.get("federation_rc_window_size"),
+ "sleep_limit": config.get("federation_rc_sleep_limit"),
+ "sleep_delay": config.get("federation_rc_sleep_delay"),
+ "reject_limit": config.get("federation_rc_reject_limit"),
+ "concurrent": config.get("federation_rc_concurrent"),
+ }
+ )
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
@@ -33,38 +73,26 @@ class RatelimitConfig(Config):
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
self.rc_login_failed_attempts = RateLimitConfig(
- rc_login_config.get("failed_attempts", {}),
+ rc_login_config.get("failed_attempts", {})
)
- self.federation_rc_window_size = config.get("federation_rc_window_size", 1000)
- self.federation_rc_sleep_limit = config.get("federation_rc_sleep_limit", 10)
- self.federation_rc_sleep_delay = config.get("federation_rc_sleep_delay", 500)
- self.federation_rc_reject_limit = config.get("federation_rc_reject_limit", 50)
- self.federation_rc_concurrent = config.get("federation_rc_concurrent", 3)
-
self.federation_rr_transactions_per_room_per_second = config.get(
- "federation_rr_transactions_per_room_per_second", 50,
+ "federation_rr_transactions_per_room_per_second", 50
)
def default_config(self, **kwargs):
return """\
## Ratelimiting ##
- # Number of messages a client can send per second
- #
- #rc_messages_per_second: 0.2
-
- # Number of message a client can send before being throttled
- #
- #rc_message_burst_count: 10.0
-
- # Ratelimiting settings for registration and login.
+ # Ratelimiting settings for client actions (registration, login, messaging).
#
# Each ratelimiting configuration is made of two parameters:
# - per_second: number of requests a client can send per second.
# - burst_count: number of requests a client can send before being throttled.
#
# Synapse currently uses the following configurations:
+ # - one for messages that ratelimits sending based on the account the client
+ # is using
# - one for registration that ratelimits registration requests based on the
# client's IP address.
# - one for login that ratelimits login requests based on the client's IP
@@ -77,6 +105,10 @@ class RatelimitConfig(Config):
#
# The defaults are as shown below.
#
+ #rc_message:
+ # per_second: 0.2
+ # burst_count: 10
+ #
#rc_registration:
# per_second: 0.17
# burst_count: 3
@@ -92,29 +124,28 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
- # The federation window size in milliseconds
- #
- #federation_rc_window_size: 1000
-
- # The number of federation requests from a single server in a window
- # before the server will delay processing the request.
- #
- #federation_rc_sleep_limit: 10
- # The duration in milliseconds to delay processing events from
- # remote servers by if they go over the sleep limit.
+ # Ratelimiting settings for incoming federation
#
- #federation_rc_sleep_delay: 500
-
- # The maximum number of concurrent federation requests allowed
- # from a single server
+ # The rc_federation configuration is made up of the following settings:
+ # - window_size: window size in milliseconds
+ # - sleep_limit: number of federation requests from a single server in
+ # a window before the server will delay processing the request.
+ # - sleep_delay: duration in milliseconds to delay processing events
+ # from remote servers by if they go over the sleep limit.
+ # - reject_limit: maximum number of concurrent federation requests
+ # allowed from a single server
+ # - concurrent: number of federation requests to concurrently process
+ # from a single server
#
- #federation_rc_reject_limit: 50
-
- # The number of federation requests to concurrently process from a
- # single server
+ # The defaults are as shown below.
#
- #federation_rc_concurrent: 3
+ #rc_federation:
+ # window_size: 1000
+ # sleep_limit: 10
+ # sleep_delay: 500
+ # reject_limit: 50
+ # concurrent: 3
# Target outgoing federation transaction frequency for sending read-receipts,
# per-room.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 8dce75c56a..f34aa42afa 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +18,8 @@
import logging
import os.path
+from netaddr import IPSet
+
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.python_dependencies import DependencyException, check_requirements
@@ -98,6 +101,11 @@ class ServerConfig(Config):
"block_non_admin_invites", False,
)
+ # Whether to enable experimental MSC1849 (aka relations) support
+ self.experimental_msc1849_support_enabled = config.get(
+ "experimental_msc1849_support_enabled", False,
+ )
+
# Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
self.max_mau_value = 0
@@ -137,6 +145,24 @@ class ServerConfig(Config):
for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True
+ self.federation_ip_range_blacklist = config.get(
+ "federation_ip_range_blacklist", [],
+ )
+
+ # Attempt to create an IPSet from the given ranges
+ try:
+ self.federation_ip_range_blacklist = IPSet(
+ self.federation_ip_range_blacklist
+ )
+
+ # Always blacklist 0.0.0.0, ::
+ self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
+ except Exception as e:
+ raise ConfigError(
+ "Invalid range(s) provided in "
+ "federation_ip_range_blacklist: %s" % e
+ )
+
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
@@ -153,6 +179,10 @@ class ServerConfig(Config):
"require_membership_for_aliases", True,
)
+ # Whether to allow per-room membership profiles through the send of membership
+ # events with profile information that differ from the target's global profile.
+ self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+
self.listeners = []
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
@@ -386,6 +416,24 @@ class ServerConfig(Config):
# - nyc.example.com
# - syd.example.com
+ # Prevent federation requests from being sent to the following
+ # blacklist IP address CIDR ranges. If this option is not specified, or
+ # specified with an empty list, no ip range blacklist will be enforced.
+ #
+ # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+ # listed here, since they correspond to unroutable addresses.)
+ #
+ federation_ip_range_blacklist:
+ - '127.0.0.0/8'
+ - '10.0.0.0/8'
+ - '172.16.0.0/12'
+ - '192.168.0.0/16'
+ - '100.64.0.0/10'
+ - '169.254.0.0/16'
+ - '::1/128'
+ - 'fe80::/64'
+ - 'fc00::/7'
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
@@ -528,6 +576,12 @@ class ServerConfig(Config):
# Defaults to 'true'.
#
#require_membership_for_aliases: false
+
+ # Whether to allow per-room membership profiles through the send of membership
+ # events with profile information that differ from the target's global profile.
+ # Defaults to 'true'.
+ #
+ #allow_per_room_profiles: false
""" % locals()
def read_arguments(self, args):
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 12056d5be2..badeb903fc 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -335,13 +335,32 @@ class FrozenEventV2(EventBase):
return self.__repr__()
def __repr__(self):
- return "<FrozenEventV2 event_id='%s', type='%s', state_key='%s'>" % (
+ return "<%s event_id='%s', type='%s', state_key='%s'>" % (
+ self.__class__.__name__,
self.event_id,
self.get("type", None),
self.get("state_key", None),
)
+class FrozenEventV3(FrozenEventV2):
+ """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
+ format_version = EventFormatVersions.V3 # All events of this type are V3
+
+ @property
+ def event_id(self):
+ # We have to import this here as otherwise we get an import loop which
+ # is hard to break.
+ from synapse.crypto.event_signing import compute_event_reference_hash
+
+ if self._event_id:
+ return self._event_id
+ self._event_id = "$" + encode_base64(
+ compute_event_reference_hash(self)[1], urlsafe=True
+ )
+ return self._event_id
+
+
def room_version_to_event_format(room_version):
"""Converts a room version string to the event format
@@ -376,6 +395,8 @@ def event_type_from_format_version(format_version):
return FrozenEvent
elif format_version == EventFormatVersions.V2:
return FrozenEventV2
+ elif format_version == EventFormatVersions.V3:
+ return FrozenEventV3
else:
raise Exception(
"No event format %r" % (format_version,)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 07fccdd8f9..27a2a9ef98 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -19,7 +19,10 @@ from six import string_types
from frozendict import frozendict
-from synapse.api.constants import EventTypes
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.util.async_helpers import yieldable_gather_results
from . import EventBase
@@ -311,3 +314,92 @@ def serialize_event(e, time_now_ms, as_client_event=True,
d = only_fields(d, only_event_fields)
return d
+
+
+class EventClientSerializer(object):
+ """Serializes events that are to be sent to clients.
+
+ This is used for bundling extra information with any events to be sent to
+ clients.
+ """
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.experimental_msc1849_support_enabled = (
+ hs.config.experimental_msc1849_support_enabled
+ )
+
+ @defer.inlineCallbacks
+ def serialize_event(self, event, time_now, **kwargs):
+ """Serializes a single event.
+
+ Args:
+ event (EventBase)
+ time_now (int): The current time in milliseconds
+ **kwargs: Arguments to pass to `serialize_event`
+
+ Returns:
+ Deferred[dict]: The serialized event
+ """
+ # To handle the case of presence events and the like
+ if not isinstance(event, EventBase):
+ defer.returnValue(event)
+
+ event_id = event.event_id
+ serialized_event = serialize_event(event, time_now, **kwargs)
+
+ # If MSC1849 is enabled then we need to look if thre are any relations
+ # we need to bundle in with the event
+ if self.experimental_msc1849_support_enabled:
+ annotations = yield self.store.get_aggregation_groups_for_event(
+ event_id,
+ )
+ references = yield self.store.get_relations_for_event(
+ event_id, RelationTypes.REFERENCE, direction="f",
+ )
+
+ if annotations.chunk:
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
+ r[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+ if references.chunk:
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
+ r[RelationTypes.REFERENCE] = references.to_dict()
+
+ edit = None
+ if event.type == EventTypes.Message:
+ edit = yield self.store.get_applicable_edit(event_id)
+
+ if edit:
+ # If there is an edit replace the content, preserving existing
+ # relations.
+
+ relations = event.content.get("m.relates_to")
+ serialized_event["content"] = edit.content.get("m.new_content", {})
+ if relations:
+ serialized_event["content"]["m.relates_to"] = relations
+ else:
+ serialized_event["content"].pop("m.relates_to", None)
+
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
+ r[RelationTypes.REPLACE] = {
+ "event_id": edit.event_id,
+ }
+
+ defer.returnValue(serialized_event)
+
+ def serialize_events(self, events, time_now, **kwargs):
+ """Serializes multiple events.
+
+ Args:
+ event (iter[EventBase])
+ time_now (int): The current time in milliseconds
+ **kwargs: Arguments to pass to `serialize_event`
+
+ Returns:
+ Deferred[list[dict]]: The list of serialized events
+ """
+ return yieldable_gather_results(
+ self.serialize_event, events,
+ time_now=time_now, **kwargs
+ )
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 9030eb18c5..385eda2dca 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -63,11 +63,7 @@ class TransportLayerServer(JsonResource):
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
self.clock,
- window_size=hs.config.federation_rc_window_size,
- sleep_limit=hs.config.federation_rc_sleep_limit,
- sleep_msec=hs.config.federation_rc_sleep_delay,
- reject_limit=hs.config.federation_rc_reject_limit,
- concurrent_requests=hs.config.federation_rc_concurrent,
+ config=hs.config.rc_federation,
)
self.register_servlets()
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index ac09d03ba9..dca337ec61 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -90,8 +90,8 @@ class BaseHandler(object):
messages_per_second = override.messages_per_second
burst_count = override.burst_count
else:
- messages_per_second = self.hs.config.rc_messages_per_second
- burst_count = self.hs.config.rc_message_burst_count
+ messages_per_second = self.hs.config.rc_message.per_second
+ burst_count = self.hs.config.rc_message.burst_count
allowed, time_allowed = self.ratelimiter.can_do_action(
user_id, time_now,
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 1b4d8c74ae..6003ad9cca 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -21,7 +21,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
-from synapse.events.utils import serialize_event
from synapse.types import UserID
from synapse.util.logutils import log_function
from synapse.visibility import filter_events_for_client
@@ -50,6 +49,7 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self._server_notices_sender = hs.get_server_notices_sender()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
@log_function
@@ -120,9 +120,9 @@ class EventStreamHandler(BaseHandler):
time_now = self.clock.time_msec()
- chunks = [
- serialize_event(e, time_now, as_client_event) for e in events
- ]
+ chunks = yield self._event_serializer.serialize_events(
+ events, time_now, as_client_event=as_client_event,
+ )
chunk = {
"chunk": chunks,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 7dfae78db0..aaee5db0b7 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -19,7 +19,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig
@@ -43,6 +42,7 @@ class InitialSyncHandler(BaseHandler):
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
+ self._event_serializer = hs.get_event_client_serializer()
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
@@ -138,7 +138,9 @@ class InitialSyncHandler(BaseHandler):
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
- d["invite"] = serialize_event(invite_event, time_now, as_client_event)
+ d["invite"] = yield self._event_serializer.serialize_event(
+ invite_event, time_now, as_client_event,
+ )
rooms_ret.append(d)
@@ -185,18 +187,21 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
d["messages"] = {
- "chunk": [
- serialize_event(m, time_now, as_client_event)
- for m in messages
- ],
+ "chunk": (
+ yield self._event_serializer.serialize_events(
+ messages, time_now=time_now,
+ as_client_event=as_client_event,
+ )
+ ),
"start": start_token.to_string(),
"end": end_token.to_string(),
}
- d["state"] = [
- serialize_event(c, time_now, as_client_event)
- for c in current_state.values()
- ]
+ d["state"] = yield self._event_serializer.serialize_events(
+ current_state.values(),
+ time_now=time_now,
+ as_client_event=as_client_event
+ )
account_data_events = []
tags = tags_by_room.get(event.room_id)
@@ -337,11 +342,15 @@ class InitialSyncHandler(BaseHandler):
"membership": membership,
"room_id": room_id,
"messages": {
- "chunk": [serialize_event(m, time_now) for m in messages],
+ "chunk": (yield self._event_serializer.serialize_events(
+ messages, time_now,
+ )),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
- "state": [serialize_event(s, time_now) for s in room_state.values()],
+ "state": (yield self._event_serializer.serialize_events(
+ room_state.values(), time_now,
+ )),
"presence": [],
"receipts": [],
})
@@ -355,10 +364,9 @@ class InitialSyncHandler(BaseHandler):
# TODO: These concurrently
time_now = self.clock.time_msec()
- state = [
- serialize_event(x, time_now)
- for x in current_state.values()
- ]
+ state = yield self._event_serializer.serialize_events(
+ current_state.values(), time_now,
+ )
now_token = yield self.hs.get_event_sources().get_current_token()
@@ -425,7 +433,9 @@ class InitialSyncHandler(BaseHandler):
ret = {
"room_id": room_id,
"messages": {
- "chunk": [serialize_event(m, time_now) for m in messages],
+ "chunk": (yield self._event_serializer.serialize_events(
+ messages, time_now,
+ )),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 8f16e12430..33ab23a0e1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -32,7 +32,6 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import ConsentURIBuilder
-from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.state import StateFilter
@@ -57,6 +56,7 @@ class MessageHandler(object):
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.store = hs.get_datastore()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None,
@@ -164,9 +164,10 @@ class MessageHandler(object):
room_state = room_state[membership_event_id]
now = self.clock.time_msec()
- defer.returnValue(
- [serialize_event(c, now) for c in room_state.values()]
+ events = yield self._event_serializer.serialize_events(
+ room_state.values(), now,
)
+ defer.returnValue(events)
@defer.inlineCallbacks
def get_joined_members(self, requester, room_id):
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index e4fdae9266..8f811e24fe 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -20,7 +20,6 @@ from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
-from synapse.events.utils import serialize_event
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
@@ -78,6 +77,7 @@ class PaginationHandler(object):
self._purges_in_progress_by_room = set()
# map from purge id to PurgeStatus
self._purges_by_id = {}
+ self._event_serializer = hs.get_event_client_serializer()
def start_purge_history(self, room_id, token,
delete_local_events=False):
@@ -278,18 +278,22 @@ class PaginationHandler(object):
time_now = self.clock.time_msec()
chunk = {
- "chunk": [
- serialize_event(e, time_now, as_client_event)
- for e in events
- ],
+ "chunk": (
+ yield self._event_serializer.serialize_events(
+ events, time_now,
+ as_client_event=as_client_event,
+ )
+ ),
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
}
if state:
- chunk["state"] = [
- serialize_event(e, time_now, as_client_event)
- for e in state
- ]
+ chunk["state"] = (
+ yield self._event_serializer.serialize_events(
+ state, time_now,
+ as_client_event=as_client_event,
+ )
+ )
defer.returnValue(chunk)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a51d11a257..e83ee24f10 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -19,7 +19,7 @@ import logging
from twisted.internet import defer
from synapse import types
-from synapse.api.constants import LoginType
+from synapse.api.constants import MAX_USERID_LENGTH, LoginType
from synapse.api.errors import (
AuthError,
Codes,
@@ -123,6 +123,15 @@ class RegistrationHandler(BaseHandler):
self.check_user_id_not_appservice_exclusive(user_id)
+ if len(user_id) > MAX_USERID_LENGTH:
+ raise SynapseError(
+ 400,
+ "User ID may not be longer than %s characters" % (
+ MAX_USERID_LENGTH,
+ ),
+ Codes.INVALID_USERNAME
+ )
+
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
if not guest_access_token:
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index e11511d395..fb8cdbc03f 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -74,6 +75,7 @@ class RoomMemberHandler(object):
self.spam_checker = hs.get_spam_checker()
self._server_notices_mxid = self.config.server_notices_mxid
self._enable_lookup = hs.config.enable_3pid_lookup
+ self.allow_per_room_profiles = self.config.allow_per_room_profiles
# This is only used to get at ratelimit function, and
# maybe_kick_guest_users. It's fine there are multiple of these as
@@ -377,6 +379,13 @@ class RoomMemberHandler(object):
# later on.
content = dict(content)
+ if not self.allow_per_room_profiles:
+ # Strip profile data, knowing that new profile data will be added to the
+ # event's content in event_creation_handler.create_event() using the target's
+ # global profile.
+ content.pop("displayname", None)
+ content.pop("avatar_url", None)
+
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
@@ -955,7 +964,7 @@ class RoomMemberHandler(object):
}
if self.config.invite_3pid_guest:
- guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
+ guest_user_id, guest_access_token = yield self.get_or_register_3pid_guest(
requester=requester,
medium=medium,
address=address,
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 49c439313e..9bba74d6c9 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -23,7 +23,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
-from synapse.events.utils import serialize_event
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -36,6 +35,7 @@ class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@@ -401,14 +401,16 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec()
for context in contexts.values():
- context["events_before"] = [
- serialize_event(e, time_now)
- for e in context["events_before"]
- ]
- context["events_after"] = [
- serialize_event(e, time_now)
- for e in context["events_after"]
- ]
+ context["events_before"] = (
+ yield self._event_serializer.serialize_events(
+ context["events_before"], time_now,
+ )
+ )
+ context["events_after"] = (
+ yield self._event_serializer.serialize_events(
+ context["events_after"], time_now,
+ )
+ )
state_results = {}
if include_state:
@@ -422,14 +424,13 @@ class SearchHandler(BaseHandler):
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise the 'age' will be wrong
- results = [
- {
+ results = []
+ for e in allowed_events:
+ results.append({
"rank": rank_map[e.event_id],
- "result": serialize_event(e, time_now),
+ "result": (yield self._event_serializer.serialize_event(e, time_now)),
"context": contexts.get(e.event_id, {}),
- }
- for e in allowed_events
- ]
+ })
rooms_cat_res = {
"results": results,
@@ -438,10 +439,13 @@ class SearchHandler(BaseHandler):
}
if state_results:
- rooms_cat_res["state"] = {
- room_id: [serialize_event(e, time_now) for e in state]
- for room_id, state in state_results.items()
- }
+ s = {}
+ for room_id, state in state_results.items():
+ s[room_id] = yield self._event_serializer.serialize_events(
+ state, time_now,
+ )
+
+ rooms_cat_res["state"] = s
if room_groups and "room_id" in group_keys:
rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9bd8f53ec8..d915faba03 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -937,7 +937,7 @@ class SyncHandler(object):
res = yield self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
- newly_joined_rooms, newly_joined_users, _, _ = res
+ newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
@@ -946,7 +946,7 @@ class SyncHandler(object):
)
if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence(
- sync_result_builder, newly_joined_rooms, newly_joined_users
+ sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
)
yield self._generate_sync_entry_for_to_device(sync_result_builder)
@@ -954,7 +954,7 @@ class SyncHandler(object):
device_lists = yield self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
- newly_joined_users=newly_joined_users,
+ newly_joined_or_invited_users=newly_joined_or_invited_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
@@ -1039,7 +1039,8 @@ class SyncHandler(object):
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder,
- newly_joined_rooms, newly_joined_users,
+ newly_joined_rooms,
+ newly_joined_or_invited_users,
newly_left_rooms, newly_left_users):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@@ -1053,7 +1054,7 @@ class SyncHandler(object):
# share a room with?
for room_id in newly_joined_rooms:
joined_users = yield self.state.get_current_users_in_room(room_id)
- newly_joined_users.update(joined_users)
+ newly_joined_or_invited_users.update(joined_users)
for room_id in newly_left_rooms:
left_users = yield self.state.get_current_users_in_room(room_id)
@@ -1061,7 +1062,7 @@ class SyncHandler(object):
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
- changed.update(newly_joined_users)
+ changed.update(newly_joined_or_invited_users)
if not changed and not newly_left_users:
defer.returnValue(DeviceLists(
@@ -1179,7 +1180,7 @@ class SyncHandler(object):
@defer.inlineCallbacks
def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms,
- newly_joined_users):
+ newly_joined_or_invited_users):
"""Generates the presence portion of the sync response. Populates the
`sync_result_builder` with the result.
@@ -1187,8 +1188,9 @@ class SyncHandler(object):
sync_result_builder(SyncResultBuilder)
newly_joined_rooms(list): List of rooms that the user has joined
since the last sync (or empty if an initial sync)
- newly_joined_users(list): List of users that have joined rooms
- since the last sync (or empty if an initial sync)
+ newly_joined_or_invited_users(list): List of users that have joined
+ or been invited to rooms since the last sync (or empty if an initial
+ sync)
"""
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
@@ -1214,7 +1216,7 @@ class SyncHandler(object):
"presence_key", presence_key
)
- extra_users_ids = set(newly_joined_users)
+ extra_users_ids = set(newly_joined_or_invited_users)
for room_id in newly_joined_rooms:
users = yield self.state.get_current_users_in_room(room_id)
extra_users_ids.update(users)
@@ -1246,7 +1248,8 @@ class SyncHandler(object):
Returns:
Deferred(tuple): Returns a 4-tuple of
- `(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)`
+ `(newly_joined_rooms, newly_joined_or_invited_users,
+ newly_left_rooms, newly_left_users)`
"""
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
@@ -1317,8 +1320,8 @@ class SyncHandler(object):
sync_result_builder.invited.extend(invited)
- # Now we want to get any newly joined users
- newly_joined_users = set()
+ # Now we want to get any newly joined or invited users
+ newly_joined_or_invited_users = set()
newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
@@ -1327,19 +1330,22 @@ class SyncHandler(object):
)
for event in it:
if event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- newly_joined_users.add(event.state_key)
+ if (
+ event.membership == Membership.JOIN or
+ event.membership == Membership.INVITE
+ ):
+ newly_joined_or_invited_users.add(event.state_key)
else:
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key)
- newly_left_users -= newly_joined_users
+ newly_left_users -= newly_joined_or_invited_users
defer.returnValue((
newly_joined_rooms,
- newly_joined_users,
+ newly_joined_or_invited_users,
newly_left_rooms,
newly_left_users,
))
@@ -1384,7 +1390,7 @@ class SyncHandler(object):
where:
room_entries is a list [RoomSyncResultBuilder]
invited_rooms is a list [InvitedSyncResult]
- newly_joined rooms is a list[str] of room ids
+ newly_joined_rooms is a list[str] of room ids
newly_left_rooms is a list[str] of room ids
"""
user_id = sync_result_builder.sync_config.user.to_string()
@@ -1425,7 +1431,7 @@ class SyncHandler(object):
if room_id in sync_result_builder.joined_room_ids and non_joins:
# Always include if the user (re)joined the room, especially
# important so that device list changes are calculated correctly.
- # If there are non join member events, but we are still in the room,
+ # If there are non-join member events, but we are still in the room,
# then the user must have left and joined
newly_joined_rooms.append(room_id)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index ddbfb72228..77fe68818b 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -165,7 +165,8 @@ class BlacklistingAgentWrapper(Agent):
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info(
- "Blocking access to %s because of blacklist" % (ip_address,)
+ "Blocking access to %s due to blacklist" %
+ (ip_address,)
)
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
@@ -263,9 +264,6 @@ class SimpleHttpClient(object):
uri (str): URI to query.
data (bytes): Data to send in the request body, if applicable.
headers (t.w.http_headers.Headers): Request headers.
-
- Raises:
- SynapseError: If the IP is blacklisted.
"""
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index ff63d0b2a8..7eefc7b1fc 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -27,9 +27,11 @@ import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
+from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError
+from twisted.internet.interfaces import IReactorPluggableNameResolver
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
@@ -44,6 +46,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.http import QuieterFileBodyProducer
+from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.util.async_helpers import timeout_deferred
from synapse.util.logcontext import make_deferred_yieldable
@@ -172,19 +175,44 @@ class MatrixFederationHttpClient(object):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
- reactor = hs.get_reactor()
+
+ real_reactor = hs.get_reactor()
+
+ # We need to use a DNS resolver which filters out blacklisted IP
+ # addresses, to prevent DNS rebinding.
+ nameResolver = IPBlacklistingResolver(
+ real_reactor, None, hs.config.federation_ip_range_blacklist,
+ )
+
+ @implementer(IReactorPluggableNameResolver)
+ class Reactor(object):
+ def __getattr__(_self, attr):
+ if attr == "nameResolver":
+ return nameResolver
+ else:
+ return getattr(real_reactor, attr)
+
+ self.reactor = Reactor()
self.agent = MatrixFederationAgent(
- hs.get_reactor(),
+ self.reactor,
tls_client_options_factory,
)
+
+ # Use a BlacklistingAgentWrapper to prevent circumventing the IP
+ # blacklist via IP literals in server names
+ self.agent = BlacklistingAgentWrapper(
+ self.agent, self.reactor,
+ ip_blacklist=hs.config.federation_ip_range_blacklist,
+ )
+
self.clock = hs.get_clock()
self._store = hs.get_datastore()
self.version_string_bytes = hs.version_string.encode('ascii')
self.default_timeout = 60
def schedule(x):
- reactor.callLater(_EPSILON, x)
+ self.reactor.callLater(_EPSILON, x)
self._cooperator = Cooperator(scheduler=schedule)
@@ -370,7 +398,7 @@ class MatrixFederationHttpClient(object):
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
- reactor=self.hs.get_reactor(),
+ reactor=self.reactor,
)
response = yield request_deferred
@@ -397,7 +425,7 @@ class MatrixFederationHttpClient(object):
d = timeout_deferred(
d,
timeout=_sec_timeout,
- reactor=self.hs.get_reactor(),
+ reactor=self.reactor,
)
try:
@@ -586,7 +614,7 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.hs.get_reactor(), self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response,
)
defer.returnValue(body)
@@ -645,7 +673,7 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
body = yield _handle_json_response(
- self.hs.get_reactor(), _sec_timeout, request, response,
+ self.reactor, _sec_timeout, request, response,
)
defer.returnValue(body)
@@ -704,7 +732,7 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.hs.get_reactor(), self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response,
)
defer.returnValue(body)
@@ -753,7 +781,7 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.hs.get_reactor(), self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response,
)
defer.returnValue(body)
@@ -801,7 +829,7 @@ class MatrixFederationHttpClient(object):
try:
d = _readBodyToFile(response, output_stream, max_size)
- d.addTimeout(self.default_timeout, self.hs.get_reactor())
+ d.addTimeout(self.default_timeout, self.reactor)
length = yield make_deferred_yieldable(d)
except Exception as e:
logger.warn(
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 2708f5e820..fdcfb90a7e 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -53,7 +53,7 @@ REQUIREMENTS = [
"pyasn1-modules>=0.0.7",
"daemonize>=2.3.1",
"bcrypt>=3.1.0",
- "pillow>=3.1.2",
+ "pillow>=4.3.0",
"sortedcontainers>=1.4.4",
"psutil>=2.0.0",
"pymacaroons>=0.13.0",
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index b457c5563f..a3952506c1 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
from synapse.storage.event_federation import EventFederationWorkerStore
from synapse.storage.event_push_actions import EventPushActionsWorkerStore
from synapse.storage.events_worker import EventsWorkerStore
+from synapse.storage.relations import RelationsWorkerStore
from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.storage.signatures import SignatureWorkerStore
from synapse.storage.state import StateGroupWorkerStore
@@ -52,6 +53,7 @@ class SlavedEventStore(EventFederationWorkerStore,
EventsWorkerStore,
SignatureWorkerStore,
UserErasureWorkerStore,
+ RelationsWorkerStore,
BaseSlavedStore):
def __init__(self, db_conn, hs):
@@ -89,7 +91,7 @@ class SlavedEventStore(EventFederationWorkerStore,
for row in rows:
self.invalidate_caches_for_event(
-token, row.event_id, row.room_id, row.type, row.state_key,
- row.redacts,
+ row.redacts, row.relates_to,
backfilled=True,
)
return super(SlavedEventStore, self).process_replication_rows(
@@ -102,7 +104,7 @@ class SlavedEventStore(EventFederationWorkerStore,
if row.type == EventsStreamEventRow.TypeId:
self.invalidate_caches_for_event(
token, data.event_id, data.room_id, data.type, data.state_key,
- data.redacts,
+ data.redacts, data.relates_to,
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
@@ -114,7 +116,8 @@ class SlavedEventStore(EventFederationWorkerStore,
raise Exception("Unknown events stream row type %s" % (row.type, ))
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
- etype, state_key, redacts, backfilled):
+ etype, state_key, redacts, relates_to,
+ backfilled):
self._invalidate_get_event_cache(event_id)
self.get_latest_event_ids_in_room.invalidate((room_id,))
@@ -136,3 +139,8 @@ class SlavedEventStore(EventFederationWorkerStore,
state_key, stream_ordering
)
self.get_invited_rooms_for_user.invalidate((state_key,))
+
+ if relates_to:
+ self.get_relations_for_event.invalidate_many((relates_to,))
+ self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+ self.get_applicable_edit.invalidate((relates_to,))
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 52564136fc..5dfead09ca 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -32,6 +32,7 @@ BackfillStreamRow = namedtuple("BackfillStreamRow", (
"type", # str
"state_key", # str, optional
"redacts", # str, optional
+ "relates_to", # str, optional
))
PresenceStreamRow = namedtuple("PresenceStreamRow", (
"user_id", # str
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index e0f6e29248..f1290d022a 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -80,11 +80,12 @@ class BaseEventsStreamRow(object):
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
- event_id = attr.ib() # str
- room_id = attr.ib() # str
- type = attr.ib() # str
- state_key = attr.ib() # str, optional
- redacts = attr.ib() # str, optional
+ event_id = attr.ib() # str
+ room_id = attr.ib() # str
+ type = attr.ib() # str
+ state_key = attr.ib() # str, optional
+ redacts = attr.ib() # str, optional
+ relates_to = attr.ib() # str, optional
@attr.s(slots=True, frozen=True)
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 3a24d31d1b..e6110ad9b1 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -44,6 +44,7 @@ from synapse.rest.client.v2_alpha import (
read_marker,
receipts,
register,
+ relations,
report_event,
room_keys,
room_upgrade_rest_servlet,
@@ -115,6 +116,7 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
+ relations.register_servlets(hs, client_resource)
# moving to /_synapse/admin
synapse.rest.admin.register_servlets_for_client_rest_resource(
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index c77d7aba68..dc63b661c0 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -19,7 +19,7 @@
import logging
import re
-from synapse.api.urls import CLIENT_PREFIX
+from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.servlet import RestServlet
from synapse.rest.client.transactions import HttpTransactionCache
@@ -36,12 +36,12 @@ def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True):
Returns:
SRE_Pattern
"""
- patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)]
+ patterns = [re.compile("^" + CLIENT_API_PREFIX + "/api/v1" + path_regex)]
if include_in_unstable:
- unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable")
+ unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
- new_prefix = CLIENT_PREFIX.replace("/api/v1", "/r%d" % release)
+ new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index cd9b3bdbd1..c3b0a39ab7 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -19,7 +19,6 @@ import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError
-from synapse.events.utils import serialize_event
from synapse.streams.config import PaginationConfig
from .base import ClientV1RestServlet, client_path_patterns
@@ -84,6 +83,7 @@ class EventRestServlet(ClientV1RestServlet):
super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request, event_id):
@@ -92,7 +92,8 @@ class EventRestServlet(ClientV1RestServlet):
time_now = self.clock.time_msec()
if event:
- defer.returnValue((200, serialize_event(event, time_now)))
+ event = yield self._event_serializer.serialize_event(event, time_now)
+ defer.returnValue((200, event))
else:
defer.returnValue((404, "Event not found."))
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index fab04965cb..255a85c588 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -26,7 +26,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter
-from synapse.events.utils import format_event_for_client_v2, serialize_event
+from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import (
assert_params_in_dict,
parse_integer,
@@ -537,6 +537,7 @@ class RoomEventServlet(ClientV1RestServlet):
super(RoomEventServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@@ -545,7 +546,8 @@ class RoomEventServlet(ClientV1RestServlet):
time_now = self.clock.time_msec()
if event:
- defer.returnValue((200, serialize_event(event, time_now)))
+ event = yield self._event_serializer.serialize_event(event, time_now)
+ defer.returnValue((200, event))
else:
defer.returnValue((404, "Event not found."))
@@ -559,6 +561,7 @@ class RoomEventContextServlet(ClientV1RestServlet):
super(RoomEventContextServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@@ -588,16 +591,18 @@ class RoomEventContextServlet(ClientV1RestServlet):
)
time_now = self.clock.time_msec()
- results["events_before"] = [
- serialize_event(event, time_now) for event in results["events_before"]
- ]
- results["event"] = serialize_event(results["event"], time_now)
- results["events_after"] = [
- serialize_event(event, time_now) for event in results["events_after"]
- ]
- results["state"] = [
- serialize_event(event, time_now) for event in results["state"]
- ]
+ results["events_before"] = yield self._event_serializer.serialize_events(
+ results["events_before"], time_now,
+ )
+ results["event"] = yield self._event_serializer.serialize_event(
+ results["event"], time_now,
+ )
+ results["events_after"] = yield self._event_serializer.serialize_events(
+ results["events_after"], time_now,
+ )
+ results["state"] = yield self._event_serializer.serialize_events(
+ results["state"], time_now,
+ )
defer.returnValue((200, results))
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 77434937ff..24ac26bf03 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -21,13 +21,12 @@ import re
from twisted.internet import defer
from synapse.api.errors import InteractiveAuthIncompleteError
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+from synapse.api.urls import CLIENT_API_PREFIX
logger = logging.getLogger(__name__)
def client_v2_patterns(path_regex, releases=(0,),
- v2_alpha=True,
unstable=True):
"""Creates a regex compiled client path with the correct client path
prefix.
@@ -39,13 +38,11 @@ def client_v2_patterns(path_regex, releases=(0,),
SRE_Pattern
"""
patterns = []
- if v2_alpha:
- patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
if unstable:
- unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
+ unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
for release in releases:
- new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
+ new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index ac035c7735..4c380ab84d 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_string
@@ -139,8 +139,8 @@ class AuthRestServlet(RestServlet):
if stagetype == LoginType.RECAPTCHA:
html = RECAPTCHA_TEMPLATE % {
'session': session,
- 'myurl': "%s/auth/%s/fallback/web" % (
- CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
+ 'myurl': "%s/r0/auth/%s/fallback/web" % (
+ CLIENT_API_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
@@ -159,8 +159,8 @@ class AuthRestServlet(RestServlet):
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
- 'myurl': "%s/auth/%s/fallback/web" % (
- CLIENT_V2_ALPHA_PREFIX, LoginType.TERMS
+ 'myurl': "%s/r0/auth/%s/fallback/web" % (
+ CLIENT_API_PREFIX, LoginType.TERMS
),
}
html_bytes = html.encode("utf8")
@@ -203,8 +203,8 @@ class AuthRestServlet(RestServlet):
else:
html = RECAPTCHA_TEMPLATE % {
'session': session,
- 'myurl': "%s/auth/%s/fallback/web" % (
- CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
+ 'myurl': "%s/r0/auth/%s/fallback/web" % (
+ CLIENT_API_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
@@ -240,8 +240,8 @@ class AuthRestServlet(RestServlet):
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
- 'myurl': "%s/auth/%s/fallback/web" % (
- CLIENT_V2_ALPHA_PREFIX, LoginType.TERMS
+ 'myurl': "%s/r0/auth/%s/fallback/web" % (
+ CLIENT_API_PREFIX, LoginType.TERMS
),
}
html_bytes = html.encode("utf8")
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 9b75bb1377..5a5be7c390 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
+ PATTERNS = client_v2_patterns("/devices$")
def __init__(self, hs):
"""
@@ -56,7 +56,7 @@ class DeleteDevicesRestServlet(RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
- PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
+ PATTERNS = client_v2_patterns("/delete_devices")
def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__()
@@ -95,7 +95,7 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
+ PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$")
def __init__(self, hs):
"""
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index 2a6ea3df5f..0a1eb0ae45 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -17,10 +17,7 @@ import logging
from twisted.internet import defer
-from synapse.events.utils import (
- format_event_for_client_v2_without_room_id,
- serialize_event,
-)
+from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from ._base import client_v2_patterns
@@ -36,6 +33,7 @@ class NotificationsServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -69,11 +67,11 @@ class NotificationsServlet(RestServlet):
"profile_tag": pa["profile_tag"],
"actions": pa["actions"],
"ts": pa["received_ts"],
- "event": serialize_event(
+ "event": (yield self._event_serializer.serialize_event(
notif_events[pa["event_id"]],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
- ),
+ )),
}
if pa["room_id"] not in receipts_by_room:
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index dc3e265bcd..042f636135 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -31,6 +31,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
+from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import (
RestServlet,
@@ -153,16 +154,18 @@ class UsernameAvailabilityRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.ratelimiter = FederationRateLimiter(
hs.get_clock(),
- # Time window of 2s
- window_size=2000,
- # Artificially delay requests if rate > sleep_limit/window_size
- sleep_limit=1,
- # Amount of artificial delay to apply
- sleep_msec=1000,
- # Error with 429 if more than reject_limit requests are queued
- reject_limit=1,
- # Allow 1 request at a time
- concurrent_requests=1,
+ FederationRateLimitConfig(
+ # Time window of 2s
+ window_size=2000,
+ # Artificially delay requests if rate > sleep_limit/window_size
+ sleep_limit=1,
+ # Amount of artificial delay to apply
+ sleep_msec=1000,
+ # Error with 429 if more than reject_limit requests are queued
+ reject_limit=1,
+ # Allow 1 request at a time
+ concurrent_requests=1,
+ )
)
@defer.inlineCallbacks
@@ -345,18 +348,22 @@ class RegisterRestServlet(RestServlet):
if self.hs.config.enable_registration_captcha:
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
- flows.extend([[LoginType.RECAPTCHA]])
+ # Also add a dummy flow here, otherwise if a client completes
+ # recaptcha first we'll assume they were going for this flow
+ # and complete the request, when they could have been trying to
+ # complete one of the flows with email/msisdn auth.
+ flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]])
# only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
- flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+ flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]])
if show_msisdn:
# only support the MSISDN-only flow if we don't require email 3PIDs
if not require_email:
- flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+ flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]])
# always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
+ [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
])
else:
# only support 3PIDless registration if no 3PIDs are required
@@ -379,7 +386,15 @@ class RegisterRestServlet(RestServlet):
if self.hs.config.user_consent_at_registration:
new_flows = []
for flow in flows:
- flow.append(LoginType.TERMS)
+ inserted = False
+ # m.login.terms should go near the end but before msisdn or email auth
+ for i, stage in enumerate(flow):
+ if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN:
+ flow.insert(i, LoginType.TERMS)
+ inserted = True
+ break
+ if not inserted:
+ flow.append(LoginType.TERMS)
flows.extend(new_flows)
auth_result, params, session_id = yield self.auth_handler.check_auth(
@@ -391,13 +406,6 @@ class RegisterRestServlet(RestServlet):
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
- #
- # Also check that we're not trying to register a 3pid that's already
- # been registered.
- #
- # This has probably happened in /register/email/requestToken as well,
- # but if a user hits this endpoint twice then clicks on each link from
- # the two activation emails, they would register the same 3pid twice.
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
@@ -413,17 +421,6 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existingUid = yield self.store.get_user_id_by_threepid(
- medium, address,
- )
-
- if existingUid is not None:
- raise SynapseError(
- 400,
- "%s is already in use" % medium,
- Codes.THREEPID_IN_USE,
- )
-
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
@@ -446,6 +443,28 @@ class RegisterRestServlet(RestServlet):
if auth_result:
threepid = auth_result.get(LoginType.EMAIL_IDENTITY)
+ # Also check that we're not trying to register a 3pid that's already
+ # been registered.
+ #
+ # This has probably happened in /register/email/requestToken as well,
+ # but if a user hits this endpoint twice then clicks on each link from
+ # the two activation emails, they would register the same 3pid twice.
+ for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+ if login_type in auth_result:
+ medium = auth_result[login_type]['medium']
+ address = auth_result[login_type]['address']
+
+ existingUid = yield self.store.get_user_id_by_threepid(
+ medium, address,
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400,
+ "%s is already in use" % medium,
+ Codes.THREEPID_IN_USE,
+ )
+
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
new file mode 100644
index 0000000000..41e0a44936
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -0,0 +1,338 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This class implements the proposed relation APIs from MSC 1849.
+
+Since the MSC has not been approved all APIs here are unstable and may change at
+any time to reflect changes in the MSC.
+"""
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
+)
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
+
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class RelationSendServlet(RestServlet):
+ """Helper API for sending events that have relation data.
+
+ Example API shape to send a 👍 reaction to a room:
+
+ POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D
+ {}
+
+ {
+ "event_id": "$foobar"
+ }
+ """
+
+ PATTERN = (
+ "/rooms/(?P<room_id>[^/]*)/send_relation"
+ "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
+ )
+
+ def __init__(self, hs):
+ super(RelationSendServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.txns = HttpTransactionCache(hs)
+
+ def register(self, http_server):
+ http_server.register_paths(
+ "POST",
+ client_v2_patterns(self.PATTERN + "$", releases=()),
+ self.on_PUT_or_POST,
+ )
+ http_server.register_paths(
+ "PUT",
+ client_v2_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
+ self.on_PUT,
+ )
+
+ def on_PUT(self, request, *args, **kwargs):
+ return self.txns.fetch_or_execute_request(
+ request, self.on_PUT_or_POST, request, *args, **kwargs
+ )
+
+ @defer.inlineCallbacks
+ def on_PUT_or_POST(
+ self, request, room_id, parent_id, relation_type, event_type, txn_id=None
+ ):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ if event_type == EventTypes.Member:
+ # Add relations to a membership is meaningless, so we just deny it
+ # at the CS API rather than trying to handle it correctly.
+ raise SynapseError(400, "Cannot send member events with relations")
+
+ content = parse_json_object_from_request(request)
+
+ aggregation_key = parse_string(request, "key", encoding="utf-8")
+
+ content["m.relates_to"] = {
+ "event_id": parent_id,
+ "key": aggregation_key,
+ "rel_type": relation_type,
+ }
+
+ event_dict = {
+ "type": event_type,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ }
+
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict=event_dict, txn_id=txn_id
+ )
+
+ defer.returnValue((200, {"event_id": event.event_id}))
+
+
+class RelationPaginationServlet(RestServlet):
+ """API to paginate relations on an event by topological ordering, optionally
+ filtered by relation type and event type.
+ """
+
+ PATTERNS = client_v2_patterns(
+ "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
+ "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
+ releases=(),
+ )
+
+ def __init__(self, hs):
+ super(RelationPaginationServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ yield self.auth.check_in_room_or_world_readable(
+ room_id, requester.user.to_string()
+ )
+
+ # This checks that a) the event exists and b) the user is allowed to
+ # view it.
+ yield self.event_handler.get_event(requester.user, room_id, parent_id)
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token = parse_string(request, "from")
+ to_token = parse_string(request, "to")
+
+ if from_token:
+ from_token = RelationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = RelationPaginationToken.from_string(to_token)
+
+ result = yield self.store.get_relations_for_event(
+ event_id=parent_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ events = yield self.store.get_events_as_list(
+ [c["event_id"] for c in result.chunk]
+ )
+
+ now = self.clock.time_msec()
+ events = yield self._event_serializer.serialize_events(events, now)
+
+ return_value = result.to_dict()
+ return_value["chunk"] = events
+
+ defer.returnValue((200, return_value))
+
+
+class RelationAggregationPaginationServlet(RestServlet):
+ """API to paginate aggregation groups of relations, e.g. paginate the
+ types and counts of the reactions on the events.
+
+ Example request and response:
+
+ GET /rooms/{room_id}/aggregations/{parent_id}
+
+ {
+ chunk: [
+ {
+ "type": "m.reaction",
+ "key": "👍",
+ "count": 3
+ }
+ ]
+ }
+ """
+
+ PATTERNS = client_v2_patterns(
+ "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
+ "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
+ releases=(),
+ )
+
+ def __init__(self, hs):
+ super(RelationAggregationPaginationServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ yield self.auth.check_in_room_or_world_readable(
+ room_id, requester.user.to_string()
+ )
+
+ # This checks that a) the event exists and b) the user is allowed to
+ # view it.
+ yield self.event_handler.get_event(requester.user, room_id, parent_id)
+
+ if relation_type not in (RelationTypes.ANNOTATION, None):
+ raise SynapseError(400, "Relation type must be 'annotation'")
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token = parse_string(request, "from")
+ to_token = parse_string(request, "to")
+
+ if from_token:
+ from_token = AggregationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = AggregationPaginationToken.from_string(to_token)
+
+ res = yield self.store.get_aggregation_groups_for_event(
+ event_id=parent_id,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ defer.returnValue((200, res.to_dict()))
+
+
+class RelationAggregationGroupPaginationServlet(RestServlet):
+ """API to paginate within an aggregation group of relations, e.g. paginate
+ all the 👍 reactions on an event.
+
+ Example request and response:
+
+ GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍
+
+ {
+ chunk: [
+ {
+ "type": "m.reaction",
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.annotation",
+ "key": "👍"
+ }
+ }
+ },
+ ...
+ ]
+ }
+ """
+
+ PATTERNS = client_v2_patterns(
+ "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
+ "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
+ releases=(),
+ )
+
+ def __init__(self, hs):
+ super(RelationAggregationGroupPaginationServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ yield self.auth.check_in_room_or_world_readable(
+ room_id, requester.user.to_string()
+ )
+
+ # This checks that a) the event exists and b) the user is allowed to
+ # view it.
+ yield self.event_handler.get_event(requester.user, room_id, parent_id)
+
+ if relation_type != RelationTypes.ANNOTATION:
+ raise SynapseError(400, "Relation type must be 'annotation'")
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token = parse_string(request, "from")
+ to_token = parse_string(request, "to")
+
+ if from_token:
+ from_token = RelationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = RelationPaginationToken.from_string(to_token)
+
+ result = yield self.store.get_relations_for_event(
+ event_id=parent_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ aggregation_key=key,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ events = yield self.store.get_events_as_list(
+ [c["event_id"] for c in result.chunk]
+ )
+
+ now = self.clock.time_msec()
+ events = yield self._event_serializer.serialize_events(events, now)
+
+ return_value = result.to_dict()
+ return_value["chunk"] = events
+
+ defer.returnValue((200, return_value))
+
+
+def register_servlets(hs, http_server):
+ RelationSendServlet(hs).register(http_server)
+ RelationPaginationServlet(hs).register(http_server)
+ RelationAggregationPaginationServlet(hs).register(http_server)
+ RelationAggregationGroupPaginationServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index 3db7ff8d1b..62b8de71fa 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -50,7 +50,6 @@ class RoomUpgradeRestServlet(RestServlet):
PATTERNS = client_v2_patterns(
# /rooms/$roomid/upgrade
"/rooms/(?P<room_id>[^/]*)/upgrade$",
- v2_alpha=False,
)
def __init__(self, hs):
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index a9e9a47a0b..21e9cef2d0 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -29,7 +29,6 @@ logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_v2_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
- v2_alpha=False
)
def __init__(self, hs):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 39d157a44b..c701e534e7 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -26,7 +26,6 @@ from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
- serialize_event,
)
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
@@ -86,6 +85,7 @@ class SyncRestServlet(RestServlet):
self.filtering = hs.get_filtering()
self.presence_handler = hs.get_presence_handler()
self._server_notices_sender = hs.get_server_notices_sender()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -168,14 +168,14 @@ class SyncRestServlet(RestServlet):
)
time_now = self.clock.time_msec()
- response_content = self.encode_response(
+ response_content = yield self.encode_response(
time_now, sync_result, requester.access_token_id, filter
)
defer.returnValue((200, response_content))
- @staticmethod
- def encode_response(time_now, sync_result, access_token_id, filter):
+ @defer.inlineCallbacks
+ def encode_response(self, time_now, sync_result, access_token_id, filter):
if filter.event_format == 'client':
event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == 'federation':
@@ -183,24 +183,24 @@ class SyncRestServlet(RestServlet):
else:
raise Exception("Unknown event format %s" % (filter.event_format, ))
- joined = SyncRestServlet.encode_joined(
+ joined = yield self.encode_joined(
sync_result.joined, time_now, access_token_id,
filter.event_fields,
event_formatter,
)
- invited = SyncRestServlet.encode_invited(
+ invited = yield self.encode_invited(
sync_result.invited, time_now, access_token_id,
event_formatter,
)
- archived = SyncRestServlet.encode_archived(
+ archived = yield self.encode_archived(
sync_result.archived, time_now, access_token_id,
filter.event_fields,
event_formatter,
)
- return {
+ defer.returnValue({
"account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
"device_lists": {
@@ -222,7 +222,7 @@ class SyncRestServlet(RestServlet):
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(),
- }
+ })
@staticmethod
def encode_presence(events, time_now):
@@ -239,8 +239,8 @@ class SyncRestServlet(RestServlet):
]
}
- @staticmethod
- def encode_joined(rooms, time_now, token_id, event_fields, event_formatter):
+ @defer.inlineCallbacks
+ def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the joined rooms in a sync result
@@ -261,15 +261,15 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = SyncRestServlet.encode_room(
+ joined[room.room_id] = yield self.encode_room(
room, time_now, token_id, joined=True, only_fields=event_fields,
event_formatter=event_formatter,
)
- return joined
+ defer.returnValue(joined)
- @staticmethod
- def encode_invited(rooms, time_now, token_id, event_formatter):
+ @defer.inlineCallbacks
+ def encode_invited(self, rooms, time_now, token_id, event_formatter):
"""
Encode the invited rooms in a sync result
@@ -289,7 +289,7 @@ class SyncRestServlet(RestServlet):
"""
invited = {}
for room in rooms:
- invite = serialize_event(
+ invite = yield self._event_serializer.serialize_event(
room.invite, time_now, token_id=token_id,
event_format=event_formatter,
is_invite=True,
@@ -302,10 +302,10 @@ class SyncRestServlet(RestServlet):
"invite_state": {"events": invited_state}
}
- return invited
+ defer.returnValue(invited)
- @staticmethod
- def encode_archived(rooms, time_now, token_id, event_fields, event_formatter):
+ @defer.inlineCallbacks
+ def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the archived rooms in a sync result
@@ -326,17 +326,17 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = SyncRestServlet.encode_room(
+ joined[room.room_id] = yield self.encode_room(
room, time_now, token_id, joined=False,
only_fields=event_fields,
event_formatter=event_formatter,
)
- return joined
+ defer.returnValue(joined)
- @staticmethod
+ @defer.inlineCallbacks
def encode_room(
- room, time_now, token_id, joined,
+ self, room, time_now, token_id, joined,
only_fields, event_formatter,
):
"""
@@ -355,9 +355,10 @@ class SyncRestServlet(RestServlet):
Returns:
dict[str, object]: the room, encoded in our response format
"""
- def serialize(event):
- return serialize_event(
- event, time_now, token_id=token_id,
+ def serialize(events):
+ return self._event_serializer.serialize_events(
+ events, time_now=time_now,
+ token_id=token_id,
event_format=event_formatter,
only_event_fields=only_fields,
)
@@ -376,8 +377,8 @@ class SyncRestServlet(RestServlet):
event.event_id, room.room_id, event.room_id,
)
- serialized_state = [serialize(e) for e in state_events]
- serialized_timeline = [serialize(e) for e in timeline_events]
+ serialized_state = yield serialize(state_events)
+ serialized_timeline = yield serialize(timeline_events)
account_data = room.account_data
@@ -397,7 +398,7 @@ class SyncRestServlet(RestServlet):
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
- return result
+ defer.returnValue(result)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index bdffa97805..8569677355 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -444,6 +444,9 @@ class MediaRepository(object):
)
return
+ if thumbnailer.transpose_method is not None:
+ m_width, m_height = thumbnailer.transpose()
+
if t_method == "crop":
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
@@ -578,6 +581,12 @@ class MediaRepository(object):
)
return
+ if thumbnailer.transpose_method is not None:
+ m_width, m_height = yield logcontext.defer_to_thread(
+ self.hs.get_reactor(),
+ thumbnailer.transpose
+ )
+
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails = {}
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 5aa03031f6..d90cbfb56a 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -108,6 +108,7 @@ class FileStorageProviderBackend(StorageProvider):
"""
def __init__(self, hs, config):
+ self.hs = hs
self.cache_directory = hs.config.media_store_path
self.base_directory = config
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index a4b26c2587..3efd0d80fc 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -20,6 +20,17 @@ import PIL.Image as Image
logger = logging.getLogger(__name__)
+EXIF_ORIENTATION_TAG = 0x0112
+EXIF_TRANSPOSE_MAPPINGS = {
+ 2: Image.FLIP_LEFT_RIGHT,
+ 3: Image.ROTATE_180,
+ 4: Image.FLIP_TOP_BOTTOM,
+ 5: Image.TRANSPOSE,
+ 6: Image.ROTATE_270,
+ 7: Image.TRANSVERSE,
+ 8: Image.ROTATE_90
+}
+
class Thumbnailer(object):
@@ -31,6 +42,30 @@ class Thumbnailer(object):
def __init__(self, input_path):
self.image = Image.open(input_path)
self.width, self.height = self.image.size
+ self.transpose_method = None
+ try:
+ # We don't use ImageOps.exif_transpose since it crashes with big EXIF
+ image_exif = self.image._getexif()
+ if image_exif is not None:
+ image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
+ self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
+ except Exception as e:
+ # A lot of parsing errors can happen when parsing EXIF
+ logger.info("Error parsing image EXIF information: %s", e)
+
+ def transpose(self):
+ """Transpose the image using its EXIF Orientation tag
+
+ Returns:
+ Tuple[int, int]: (width, height) containing the new image size in pixels.
+ """
+ if self.transpose_method is not None:
+ self.image = self.image.transpose(self.transpose_method)
+ self.width, self.height = self.image.size
+ self.transpose_method = None
+ # We don't need EXIF any more
+ self.image.info["exif"] = None
+ return self.image.size
def aspect(self, max_width, max_height):
"""Calculate the largest size that preserves aspect ratio which
diff --git a/synapse/server.py b/synapse/server.py
index 8c30ac2fa5..80d40b9272 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -35,6 +35,7 @@ from synapse.crypto import context_factory
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker
+from synapse.events.utils import EventClientSerializer
from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import (
FederationHandlerRegistry,
@@ -185,6 +186,7 @@ class HomeServer(object):
'sendmail',
'registration_handler',
'account_validity_handler',
+ 'event_client_serializer',
]
REQUIRED_ON_MASTER_STARTUP = [
@@ -511,6 +513,9 @@ class HomeServer(object):
def build_account_validity_handler(self):
return AccountValidityHandler(self)
+ def build_event_client_serializer(self):
+ return EventClientSerializer(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index c432041b4e..7522d3fd57 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -49,6 +49,7 @@ from .pusher import PusherStore
from .receipts import ReceiptsStore
from .registration import RegistrationStore
from .rejections import RejectionsStore
+from .relations import RelationsStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
@@ -99,6 +100,7 @@ class DataStore(
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
+ RelationsStore,
):
def __init__(self, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 6092f600ba..eb329ebd8b 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -302,7 +302,7 @@ class ApplicationServiceTransactionWorkerStore(
event_ids = json.loads(entry["event_ids"])
- events = yield self._get_events(event_ids)
+ events = yield self.get_events_as_list(event_ids)
defer.returnValue(
AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
@@ -358,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
- events = yield self._get_events(event_ids)
+ events = yield self.get_events_as_list(event_ids)
defer.returnValue((upper_bound, events))
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 956f876572..09e39c2c28 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -45,7 +45,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
return self.get_auth_chain_ids(
event_ids, include_given=include_given
- ).addCallback(self._get_events)
+ ).addCallback(self.get_events_as_list)
def get_auth_chain_ids(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
@@ -316,7 +316,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_list,
limit,
)
- .addCallback(self._get_events)
+ .addCallback(self.get_events_as_list)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
)
@@ -382,7 +382,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
- events = yield self._get_events(ids)
+ events = yield self.get_events_as_list(ids)
defer.returnValue(events)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 7a7f841c6c..881d6d0126 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1325,6 +1325,9 @@ class EventsStore(
txn, event.room_id, event.redacts
)
+ # Remove from relations table.
+ self._handle_redaction(txn, event.redacts)
+
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
@@ -1351,6 +1354,8 @@ class EventsStore(
# Insert into the event_search table.
self._store_guest_access_txn(txn, event)
+ self._handle_event_relations(txn, event)
+
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@@ -1655,10 +1660,11 @@ class EventsStore(
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
@@ -1673,11 +1679,12 @@ class EventsStore(
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
@@ -1698,10 +1705,11 @@ class EventsStore(
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
@@ -1716,11 +1724,12 @@ class EventsStore(
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 663991a9b6..adc6cf26b5 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -103,7 +103,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred : A FrozenEvent.
"""
- events = yield self._get_events(
+ events = yield self.get_events_as_list(
[event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@@ -142,7 +142,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred : Dict from event_id to event.
"""
- events = yield self._get_events(
+ events = yield self.get_events_as_list(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@@ -152,13 +152,32 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@defer.inlineCallbacks
- def _get_events(
+ def get_events_as_list(
self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
):
+ """Get events from the database and return in a list in the same order
+ as given by `event_ids` arg.
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+ Deferred[list[EventBase]]: List of events fetched from the database. The
+ events are in the same order as `event_ids` arg.
+
+ Note that the returned list may be smaller than the list of event
+ IDs if not all events could be fetched.
+ """
+
if not event_ids:
defer.returnValue([])
@@ -202,21 +221,22 @@ class EventsWorkerStore(SQLBaseStore):
#
# The problem is that we end up at this point when an event
# which has been redacted is pulled out of the database by
- # _enqueue_events, because _enqueue_events needs to check the
- # redaction before it can cache the redacted event. So obviously,
- # calling get_event to get the redacted event out of the database
- # gives us an infinite loop.
+ # _enqueue_events, because _enqueue_events needs to check
+ # the redaction before it can cache the redacted event. So
+ # obviously, calling get_event to get the redacted event out
+ # of the database gives us an infinite loop.
#
- # For now (quick hack to fix during 0.99 release cycle), we just
- # go and fetch the relevant row from the db, but it would be nice
- # to think about how we can cache this rather than hit the db
- # every time we access a redaction event.
+ # For now (quick hack to fix during 0.99 release cycle), we
+ # just go and fetch the relevant row from the db, but it
+ # would be nice to think about how we can cache this rather
+ # than hit the db every time we access a redaction event.
#
# One thought on how to do this:
- # 1. split _get_events up so that it is divided into (a) get the
- # rawish event from the db/cache, (b) do the redaction/rejection
- # filtering
- # 2. have _get_event_from_row just call the first half of that
+ # 1. split get_events_as_list up so that it is divided into
+ # (a) get the rawish event from the db/cache, (b) do the
+ # redaction/rejection filtering
+ # 2. have _get_event_from_row just call the first half of
+ # that
orig_sender = yield self._simple_select_one_onecol(
table="events",
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
new file mode 100644
index 0000000000..493abe405e
--- /dev/null
+++ b/synapse/storage/relations.py
@@ -0,0 +1,434 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+
+from twisted.internet import defer
+
+from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.stream import generate_pagination_where_clause
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s
+class PaginationChunk(object):
+ """Returned by relation pagination APIs.
+
+ Attributes:
+ chunk (list): The rows returned by pagination
+ next_batch (Any|None): Token to fetch next set of results with, if
+ None then there are no more results.
+ prev_batch (Any|None): Token to fetch previous set of results with, if
+ None then there are no previous results.
+ """
+
+ chunk = attr.ib()
+ next_batch = attr.ib(default=None)
+ prev_batch = attr.ib(default=None)
+
+ def to_dict(self):
+ d = {"chunk": self.chunk}
+
+ if self.next_batch:
+ d["next_batch"] = self.next_batch.to_string()
+
+ if self.prev_batch:
+ d["prev_batch"] = self.prev_batch.to_string()
+
+ return d
+
+
+@attr.s(frozen=True, slots=True)
+class RelationPaginationToken(object):
+ """Pagination token for relation pagination API.
+
+ As the results are order by topological ordering, we can use the
+ `topological_ordering` and `stream_ordering` fields of the events at the
+ boundaries of the chunk as pagination tokens.
+
+ Attributes:
+ topological (int): The topological ordering of the boundary event
+ stream (int): The stream ordering of the boundary event.
+ """
+
+ topological = attr.ib()
+ stream = attr.ib()
+
+ @staticmethod
+ def from_string(string):
+ try:
+ t, s = string.split("-")
+ return RelationPaginationToken(int(t), int(s))
+ except ValueError:
+ raise SynapseError(400, "Invalid token")
+
+ def to_string(self):
+ return "%d-%d" % (self.topological, self.stream)
+
+ def as_tuple(self):
+ return attr.astuple(self)
+
+
+@attr.s(frozen=True, slots=True)
+class AggregationPaginationToken(object):
+ """Pagination token for relation aggregation pagination API.
+
+ As the results are order by count and then MAX(stream_ordering) of the
+ aggregation groups, we can just use them as our pagination token.
+
+ Attributes:
+ count (int): The count of relations in the boundar group.
+ stream (int): The MAX stream ordering in the boundary group.
+ """
+
+ count = attr.ib()
+ stream = attr.ib()
+
+ @staticmethod
+ def from_string(string):
+ try:
+ c, s = string.split("-")
+ return AggregationPaginationToken(int(c), int(s))
+ except ValueError:
+ raise SynapseError(400, "Invalid token")
+
+ def to_string(self):
+ return "%d-%d" % (self.count, self.stream)
+
+ def as_tuple(self):
+ return attr.astuple(self)
+
+
+class RelationsWorkerStore(SQLBaseStore):
+ @cached(tree=True)
+ def get_relations_for_event(
+ self,
+ event_id,
+ relation_type=None,
+ event_type=None,
+ aggregation_key=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of relations for an event, ordered by topological ordering.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ relation_type (str|None): Only fetch events with this relation
+ type, if given.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ aggregation_key (str|None): Only fetch events with this aggregation
+ key, if given.
+ limit (int): Only fetch the most recent `limit` events.
+ direction (str): Whether to fetch the most recent first (`"b"`) or
+ the oldest first (`"f"`).
+ from_token (RelationPaginationToken|None): Fetch rows from the given
+ token, or from the start if None.
+ to_token (RelationPaginationToken|None): Fetch rows up to the given
+ token, or up to the end if None.
+
+ Returns:
+ Deferred[PaginationChunk]: List of event IDs that match relations
+ requested. The rows are of the form `{"event_id": "..."}`.
+ """
+
+ where_clause = ["relates_to_id = ?"]
+ where_args = [event_id]
+
+ if relation_type is not None:
+ where_clause.append("relation_type = ?")
+ where_args.append(relation_type)
+
+ if event_type is not None:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ if aggregation_key:
+ where_clause.append("aggregation_key = ?")
+ where_args.append(aggregation_key)
+
+ pagination_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("topological_ordering", "stream_ordering"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if pagination_clause:
+ where_clause.append(pagination_clause)
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ sql = """
+ SELECT event_id, topological_ordering, stream_ordering
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE %s
+ ORDER BY topological_ordering %s, stream_ordering %s
+ LIMIT ?
+ """ % (
+ " AND ".join(where_clause),
+ order,
+ order,
+ )
+
+ def _get_recent_references_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ last_topo_id = None
+ last_stream_id = None
+ events = []
+ for row in txn:
+ events.append({"event_id": row[0]})
+ last_topo_id = row[1]
+ last_stream_id = row[2]
+
+ next_batch = None
+ if len(events) > limit and last_topo_id and last_stream_id:
+ next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.runInteraction(
+ "get_recent_references_for_event", _get_recent_references_for_event_txn
+ )
+
+ @cached(tree=True)
+ def get_aggregation_groups_for_event(
+ self,
+ event_id,
+ event_type=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of annotations on the event, grouped by event type and
+ aggregation key, sorted by count.
+
+ This is used e.g. to get the what and how many reactions have happend
+ on an event.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ limit (int): Only fetch the `limit` groups.
+ direction (str): Whether to fetch the highest count first (`"b"`) or
+ the lowest count first (`"f"`).
+ from_token (AggregationPaginationToken|None): Fetch rows from the
+ given token, or from the start if None.
+ to_token (AggregationPaginationToken|None): Fetch rows up to the
+ given token, or up to the end if None.
+
+
+ Returns:
+ Deferred[PaginationChunk]: List of groups of annotations that
+ match. Each row is a dict with `type`, `key` and `count` fields.
+ """
+
+ where_clause = ["relates_to_id = ?", "relation_type = ?"]
+ where_args = [event_id, RelationTypes.ANNOTATION]
+
+ if event_type:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ having_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("COUNT(*)", "MAX(stream_ordering)"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ if having_clause:
+ having_clause = "HAVING " + having_clause
+ else:
+ having_clause = ""
+
+ sql = """
+ SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE {where_clause}
+ GROUP BY relation_type, type, aggregation_key
+ {having_clause}
+ ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+ LIMIT ?
+ """.format(
+ where_clause=" AND ".join(where_clause),
+ order=order,
+ having_clause=having_clause,
+ )
+
+ def _get_aggregation_groups_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ next_batch = None
+ events = []
+ for row in txn:
+ events.append({"type": row[0], "key": row[1], "count": row[2]})
+ next_batch = AggregationPaginationToken(row[2], row[3])
+
+ if len(events) <= limit:
+ next_batch = None
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.runInteraction(
+ "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+ )
+
+ @cachedInlineCallbacks()
+ def get_applicable_edit(self, event_id):
+ """Get the most recent edit (if any) that has happened for the given
+ event.
+
+ Correctly handles checking whether edits were allowed to happen.
+
+ Args:
+ event_id (str): The original event ID
+
+ Returns:
+ Deferred[EventBase|None]: Returns the most recent edit, if any.
+ """
+
+ # We only allow edits for `m.room.message` events that have the same sender
+ # and event type. We can't assert these things during regular event auth so
+ # we have to do the checks post hoc.
+
+ # Fetches latest edit that has the same type and sender as the
+ # original, and is an `m.room.message`.
+ sql = """
+ SELECT edit.event_id FROM events AS edit
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND edit.type = 'm.room.message'
+ ORDER by edit.origin_server_ts DESC, edit.event_id DESC
+ LIMIT 1
+ """
+
+ def _get_applicable_edit_txn(txn):
+ txn.execute(
+ sql, (event_id, RelationTypes.REPLACE,)
+ )
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ edit_id = yield self.runInteraction(
+ "get_applicable_edit", _get_applicable_edit_txn
+ )
+
+ if not edit_id:
+ return
+
+ edit_event = yield self.get_event(edit_id, allow_none=True)
+ defer.returnValue(edit_event)
+
+
+class RelationsStore(RelationsWorkerStore):
+ def _handle_event_relations(self, txn, event):
+ """Handles inserting relation data during peristence of events
+
+ Args:
+ txn
+ event (EventBase)
+ """
+ relation = event.content.get("m.relates_to")
+ if not relation:
+ # No relations
+ return
+
+ rel_type = relation.get("rel_type")
+ if rel_type not in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ RelationTypes.REPLACE,
+ ):
+ # Unknown relation type
+ return
+
+ parent_id = relation.get("event_id")
+ if not parent_id:
+ # Invalid relation
+ return
+
+ aggregation_key = relation.get("key")
+
+ self._simple_insert_txn(
+ txn,
+ table="event_relations",
+ values={
+ "event_id": event.event_id,
+ "relates_to_id": parent_id,
+ "relation_type": rel_type,
+ "aggregation_key": aggregation_key,
+ },
+ )
+
+ txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
+ txn.call_after(
+ self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+ )
+
+ if rel_type == RelationTypes.REPLACE:
+ txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
+
+ def _handle_redaction(self, txn, redacted_event_id):
+ """Handles receiving a redaction and checking whether we need to remove
+ any redacted relations from the database.
+
+ Args:
+ txn
+ redacted_event_id (str): The event that was redacted.
+ """
+
+ self._simple_delete_txn(
+ txn,
+ table="event_relations",
+ keyvalues={
+ "event_id": redacted_event_id,
+ }
+ )
diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/schema/delta/54/relations.sql
new file mode 100644
index 0000000000..134862b870
--- /dev/null
+++ b/synapse/storage/schema/delta/54/relations.sql
@@ -0,0 +1,27 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks related events, like reactions, replies, edits, etc. Note that things
+-- in this table are not necessarily "valid", e.g. it may contain edits from
+-- people who don't have power to edit other peoples events.
+CREATE TABLE IF NOT EXISTS event_relations (
+ event_id TEXT NOT NULL,
+ relates_to_id TEXT NOT NULL,
+ relation_type TEXT NOT NULL,
+ aggregation_key TEXT
+);
+
+CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id);
+CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key);
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 23f20a5a3a..49a7e4dec3 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -460,7 +460,7 @@ class SearchStore(BackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self._get_events([r["event_id"] for r in results])
+ events = yield self.get_events_as_list([r["event_id"] for r in results])
event_map = {ev.event_id: ev for ev in events}
@@ -605,7 +605,7 @@ class SearchStore(BackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self._get_events([r["event_id"] for r in results])
+ events = yield self.get_events_as_list([r["event_id"] for r in results])
event_map = {ev.event_id: ev for ev in events}
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 9cd1e0f9fe..529ad4ea79 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -64,59 +64,135 @@ _EventDictReturn = namedtuple(
)
-def lower_bound(token, engine, inclusive=False):
- inclusive = "=" if inclusive else ""
- if token.topological is None:
- return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
- else:
- if isinstance(engine, PostgresEngine):
- # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
- # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
- # use the later form when running against postgres.
- return "((%d,%d) <%s (%s,%s))" % (
- token.topological,
- token.stream,
- inclusive,
- "topological_ordering",
- "stream_ordering",
+def generate_pagination_where_clause(
+ direction, column_names, from_token, to_token, engine,
+):
+ """Creates an SQL expression to bound the columns by the pagination
+ tokens.
+
+ For example creates an SQL expression like:
+
+ (6, 7) >= (topological_ordering, stream_ordering)
+ AND (5, 3) < (topological_ordering, stream_ordering)
+
+ would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3).
+
+ Note that tokens are considered to be after the row they are in, e.g. if
+ a row A has a token T, then we consider A to be before T. This convention
+ is important when figuring out inequalities for the generated SQL, and
+ produces the following result:
+ - If paginating forwards then we exclude any rows matching the from
+ token, but include those that match the to token.
+ - If paginating backwards then we include any rows matching the from
+ token, but include those that match the to token.
+
+ Args:
+ direction (str): Whether we're paginating backwards("b") or
+ forwards ("f").
+ column_names (tuple[str, str]): The column names to bound. Must *not*
+ be user defined as these get inserted directly into the SQL
+ statement without escapes.
+ from_token (tuple[int, int]|None): The start point for the pagination.
+ This is an exclusive minimum bound if direction is "f", and an
+ inclusive maximum bound if direction is "b".
+ to_token (tuple[int, int]|None): The endpoint point for the pagination.
+ This is an inclusive maximum bound if direction is "f", and an
+ exclusive minimum bound if direction is "b".
+ engine: The database engine to generate the clauses for
+
+ Returns:
+ str: The sql expression
+ """
+ assert direction in ("b", "f")
+
+ where_clause = []
+ if from_token:
+ where_clause.append(
+ _make_generic_sql_bound(
+ bound=">=" if direction == "b" else "<",
+ column_names=column_names,
+ values=from_token,
+ engine=engine,
)
- return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
- token.topological,
- "topological_ordering",
- token.topological,
- "topological_ordering",
- token.stream,
- inclusive,
- "stream_ordering",
- )
-
-
-def upper_bound(token, engine, inclusive=True):
- inclusive = "=" if inclusive else ""
- if token.topological is None:
- return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
- else:
- if isinstance(engine, PostgresEngine):
- # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
- # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
- # use the later form when running against postgres.
- return "((%d,%d) >%s (%s,%s))" % (
- token.topological,
- token.stream,
- inclusive,
- "topological_ordering",
- "stream_ordering",
+ )
+
+ if to_token:
+ where_clause.append(
+ _make_generic_sql_bound(
+ bound="<" if direction == "b" else ">=",
+ column_names=column_names,
+ values=to_token,
+ engine=engine,
)
- return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
- token.topological,
- "topological_ordering",
- token.topological,
- "topological_ordering",
- token.stream,
- inclusive,
- "stream_ordering",
)
+ return " AND ".join(where_clause)
+
+
+def _make_generic_sql_bound(bound, column_names, values, engine):
+ """Create an SQL expression that bounds the given column names by the
+ values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
+
+ Only works with two columns.
+
+ Older versions of SQLite don't support that syntax so we have to expand it
+ out manually.
+
+ Args:
+ bound (str): The comparison operator to use. One of ">", "<", ">=",
+ "<=", where the values are on the left and columns on the right.
+ names (tuple[str, str]): The column names. Must *not* be user defined
+ as these get inserted directly into the SQL statement without
+ escapes.
+ values (tuple[int|None, int]): The values to bound the columns by. If
+ the first value is None then only creates a bound on the second
+ column.
+ engine: The database engine to generate the SQL for
+
+ Returns:
+ str
+ """
+
+ assert(bound in (">", "<", ">=", "<="))
+
+ name1, name2 = column_names
+ val1, val2 = values
+
+ if val1 is None:
+ val2 = int(val2)
+ return "(%d %s %s)" % (val2, bound, name2)
+
+ val1 = int(val1)
+ val2 = int(val2)
+
+ if isinstance(engine, PostgresEngine):
+ # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
+ # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
+ # use the later form when running against postgres.
+ return "((%d,%d) %s (%s,%s))" % (
+ val1, val2,
+ bound,
+ name1, name2,
+ )
+
+ # We want to generate queries of e.g. the form:
+ #
+ # (val1 < name1 OR (val1 = name1 AND val2 <= name2))
+ #
+ # which is equivalent to (val1, val2) < (name1, name2)
+
+ return """(
+ {val1:d} {strict_bound} {name1}
+ OR ({val1:d} = {name1} AND {val2:d} {bound} {name2})
+ )""".format(
+ name1=name1,
+ val1=val1,
+ name2=name2,
+ val2=val2,
+ strict_bound=bound[0], # The first bound must always be strict equality here
+ bound=bound,
+ )
+
def filter_to_clause(event_filter):
# NB: This may create SQL clauses that don't optimise well (and we don't
@@ -319,7 +395,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
- ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
+ ret = yield self.get_events_as_list([
+ r.event_id for r in rows], get_prev_content=True,
+ )
self._set_before_and_after(ret, rows, topo_order=from_id is None)
@@ -367,7 +445,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_membership_changes_for_user", f)
- ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
+ ret = yield self.get_events_as_list(
+ [r.event_id for r in rows], get_prev_content=True,
+ )
self._set_before_and_after(ret, rows, topo_order=False)
@@ -394,7 +474,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
logger.debug("stream before")
- events = yield self._get_events(
+ events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
logger.debug("stream after")
@@ -580,11 +660,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events_before = yield self._get_events(
+ events_before = yield self.get_events_as_list(
[e for e in results["before"]["event_ids"]], get_prev_content=True
)
- events_after = yield self._get_events(
+ events_after = yield self.get_events_as_list(
[e for e in results["after"]["event_ids"]], get_prev_content=True
)
@@ -697,7 +777,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_all_new_events_stream", get_all_new_events_stream_txn
)
- events = yield self._get_events(event_ids)
+ events = yield self.get_events_as_list(event_ids)
defer.returnValue((upper_bound, events))
@@ -758,20 +838,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
- bounds = upper_bound(from_token, self.database_engine)
- if to_token:
- bounds = "%s AND %s" % (
- bounds,
- lower_bound(to_token, self.database_engine),
- )
else:
order = "ASC"
- bounds = lower_bound(from_token, self.database_engine)
- if to_token:
- bounds = "%s AND %s" % (
- bounds,
- upper_bound(to_token, self.database_engine),
- )
+
+ bounds = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("topological_ordering", "stream_ordering"),
+ from_token=from_token,
+ to_token=to_token,
+ engine=self.database_engine,
+ )
filter_clause, filter_args = filter_to_clause(event_filter)
@@ -849,7 +925,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events = yield self._get_events(
+ events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 2f16f23d91..7253ba120f 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -156,6 +156,25 @@ def concurrently_execute(func, args, limit):
], consumeErrors=True)).addErrback(unwrapFirstError)
+def yieldable_gather_results(func, iter, *args, **kwargs):
+ """Executes the function with each argument concurrently.
+
+ Args:
+ func (func): Function to execute that returns a Deferred
+ iter (iter): An iterable that yields items that get passed as the first
+ argument to the function
+ *args: Arguments to be passed to each call to func
+
+ Returns
+ Deferred[list]: Resolved when all functions have been invoked, or errors if
+ one of the function calls fails.
+ """
+ return logcontext.make_deferred_yieldable(defer.gatherResults([
+ run_in_background(func, item, *args, **kwargs)
+ for item in iter
+ ], consumeErrors=True)).addErrback(unwrapFirstError)
+
+
class Linearizer(object):
"""Limits concurrent access to resources based on a key. Useful to ensure
only a few things happen at a time on a given resource.
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 194da87639..e14c8bdfda 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -27,6 +27,7 @@ def user_left_room(distributor, user, room_id):
distributor.fire("user_left_room", user=user, room_id=room_id)
+# XXX: this is no longer used. We should probably kill it.
def user_joined_room(distributor, user, room_id):
distributor.fire("user_joined_room", user=user, room_id=room_id)
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 7deb38f2a7..b146d137f4 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -30,31 +30,14 @@ logger = logging.getLogger(__name__)
class FederationRateLimiter(object):
- def __init__(self, clock, window_size, sleep_limit, sleep_msec,
- reject_limit, concurrent_requests):
+ def __init__(self, clock, config):
"""
Args:
clock (Clock)
- window_size (int): The window size in milliseconds.
- sleep_limit (int): The number of requests received in the last
- `window_size` milliseconds before we artificially start
- delaying processing of requests.
- sleep_msec (int): The number of milliseconds to delay processing
- of incoming requests by.
- reject_limit (int): The maximum number of requests that are can be
- queued for processing before we start rejecting requests with
- a 429 Too Many Requests response.
- concurrent_requests (int): The number of concurrent requests to
- process.
+ config (FederationRateLimitConfig)
"""
self.clock = clock
-
- self.window_size = window_size
- self.sleep_limit = sleep_limit
- self.sleep_msec = sleep_msec
- self.reject_limit = reject_limit
- self.concurrent_requests = concurrent_requests
-
+ self._config = config
self.ratelimiters = {}
def ratelimit(self, host):
@@ -76,25 +59,25 @@ class FederationRateLimiter(object):
host,
_PerHostRatelimiter(
clock=self.clock,
- window_size=self.window_size,
- sleep_limit=self.sleep_limit,
- sleep_msec=self.sleep_msec,
- reject_limit=self.reject_limit,
- concurrent_requests=self.concurrent_requests,
+ config=self._config,
)
).ratelimit()
class _PerHostRatelimiter(object):
- def __init__(self, clock, window_size, sleep_limit, sleep_msec,
- reject_limit, concurrent_requests):
+ def __init__(self, clock, config):
+ """
+ Args:
+ clock (Clock)
+ config (FederationRateLimitConfig)
+ """
self.clock = clock
- self.window_size = window_size
- self.sleep_limit = sleep_limit
- self.sleep_sec = sleep_msec / 1000.0
- self.reject_limit = reject_limit
- self.concurrent_requests = concurrent_requests
+ self.window_size = config.window_size
+ self.sleep_limit = config.sleep_limit
+ self.sleep_sec = config.sleep_delay / 1000.0
+ self.reject_limit = config.reject_limit
+ self.concurrent_requests = config.concurrent
# request_id objects for requests which have been slept
self.sleeping_requests = set()
|