diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index a342a0e0da..37e31d2b6f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
-from synapse.types import UserID
+from synapse.types import UserID, ClientInfo
import logging
@@ -102,6 +102,8 @@ class Auth(object):
def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id)
+ logger.debug("Got curr_state %s", curr_state)
+
for event in curr_state:
if event.type == EventTypes.Member:
try:
@@ -290,7 +292,9 @@ class Auth(object):
Args:
request - An HTTP request with an access_token query parameter.
Returns:
- UserID : User ID object of the user making the request
+ tuple : of UserID and device string:
+ User ID object of the user making the request
+ Client ID object of the client instance the user is using
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@@ -299,6 +303,8 @@ class Auth(object):
access_token = request.args["access_token"][0]
user_info = yield self.get_user_by_token(access_token)
user = user_info["user"]
+ device_id = user_info["device_id"]
+ token_id = user_info["token_id"]
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
@@ -314,7 +320,7 @@ class Auth(object):
user_agent=user_agent
)
- defer.returnValue(user)
+ defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError:
raise AuthError(403, "Missing access token.")
@@ -339,6 +345,7 @@ class Auth(object):
"admin": bool(ret.get("admin", False)),
"device_id": ret.get("device_id"),
"user": UserID.from_string(ret.get("name")),
+ "token_id": ret.get("token_id", None),
}
defer.returnValue(user_info)
@@ -353,9 +360,23 @@ class Auth(object):
def add_auth_events(self, builder, context):
yield run_on_reactor()
- if builder.type == EventTypes.Create:
- builder.auth_events = []
- return
+ auth_ids = self.compute_auth_events(builder, context)
+
+ auth_events_entries = yield self.store.add_event_hashes(
+ auth_ids
+ )
+
+ builder.auth_events = auth_events_entries
+
+ context.auth_events = {
+ k: v
+ for k, v in context.current_state.items()
+ if v.event_id in auth_ids
+ }
+
+ def compute_auth_events(self, event, context):
+ if event.type == EventTypes.Create:
+ return []
auth_ids = []
@@ -368,7 +389,7 @@ class Auth(object):
key = (EventTypes.JoinRules, "", )
join_rule_event = context.current_state.get(key)
- key = (EventTypes.Member, builder.user_id, )
+ key = (EventTypes.Member, event.user_id, )
member_event = context.current_state.get(key)
key = (EventTypes.Create, "", )
@@ -382,8 +403,8 @@ class Auth(object):
else:
is_public = False
- if builder.type == EventTypes.Member:
- e_type = builder.content["membership"]
+ if event.type == EventTypes.Member:
+ e_type = event.content["membership"]
if e_type in [Membership.JOIN, Membership.INVITE]:
if join_rule_event:
auth_ids.append(join_rule_event.event_id)
@@ -398,17 +419,7 @@ class Auth(object):
if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id)
- auth_events_entries = yield self.store.add_event_hashes(
- auth_ids
- )
-
- builder.auth_events = auth_events_entries
-
- context.auth_events = {
- k: v
- for k, v in context.current_state.items()
- if v.event_id in auth_ids
- }
+ return auth_ids
@log_function
def _can_send_event(self, event, auth_events):
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 7ee6dcc46e..0d3fc629af 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -74,3 +74,9 @@ class EventTypes(object):
Message = "m.room.message"
Topic = "m.room.topic"
Name = "m.room.name"
+
+
+class RejectedReason(object):
+ AUTH_ERROR = "auth_error"
+ REPLACED = "replaced"
+ NOT_ANCESTOR = "not_ancestor"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 2b049debf3..ad478aa6b7 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -21,6 +21,7 @@ logger = logging.getLogger(__name__)
class Codes(object):
+ UNRECOGNIZED = "M_UNRECOGNIZED"
UNAUTHORIZED = "M_UNAUTHORIZED"
FORBIDDEN = "M_FORBIDDEN"
BAD_JSON = "M_BAD_JSON"
@@ -34,6 +35,7 @@ class Codes(object):
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
+ MISSING_PARAM = "M_MISSING_PARAM",
TOO_LARGE = "M_TOO_LARGE"
@@ -81,6 +83,35 @@ class RegistrationError(SynapseError):
pass
+class UnrecognizedRequestError(SynapseError):
+ """An error indicating we don't understand the request you're trying to make"""
+ def __init__(self, *args, **kwargs):
+ if "errcode" not in kwargs:
+ kwargs["errcode"] = Codes.UNRECOGNIZED
+ message = None
+ if len(args) == 0:
+ message = "Unrecognized request"
+ else:
+ message = args[0]
+ super(UnrecognizedRequestError, self).__init__(
+ 400,
+ message,
+ **kwargs
+ )
+
+
+class NotFoundError(SynapseError):
+ """An error indicating we can't find the thing you asked for"""
+ def __init__(self, *args, **kwargs):
+ if "errcode" not in kwargs:
+ kwargs["errcode"] = Codes.NOT_FOUND
+ super(NotFoundError, self).__init__(
+ 404,
+ "Not found",
+ **kwargs
+ )
+
+
class AuthError(SynapseError):
"""An error raised when there was a problem authorising an event."""
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
new file mode 100644
index 0000000000..4d570b74f8
--- /dev/null
+++ b/synapse/api/filtering.py
@@ -0,0 +1,229 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.api.errors import SynapseError
+from synapse.types import UserID, RoomID
+
+
+class Filtering(object):
+
+ def __init__(self, hs):
+ super(Filtering, self).__init__()
+ self.store = hs.get_datastore()
+
+ def get_user_filter(self, user_localpart, filter_id):
+ result = self.store.get_user_filter(user_localpart, filter_id)
+ result.addCallback(Filter)
+ return result
+
+ def add_user_filter(self, user_localpart, user_filter):
+ self._check_valid_filter(user_filter)
+ return self.store.add_user_filter(user_localpart, user_filter)
+
+ # TODO(paul): surely we should probably add a delete_user_filter or
+ # replace_user_filter at some point? There's no REST API specified for
+ # them however
+
+ def _check_valid_filter(self, user_filter_json):
+ """Check if the provided filter is valid.
+
+ This inspects all definitions contained within the filter.
+
+ Args:
+ user_filter_json(dict): The filter
+ Raises:
+ SynapseError: If the filter is not valid.
+ """
+ # NB: Filters are the complete json blobs. "Definitions" are an
+ # individual top-level key e.g. public_user_data. Filters are made of
+ # many definitions.
+
+ top_level_definitions = [
+ "public_user_data", "private_user_data", "server_data"
+ ]
+
+ room_level_definitions = [
+ "state", "events", "ephemeral"
+ ]
+
+ for key in top_level_definitions:
+ if key in user_filter_json:
+ self._check_definition(user_filter_json[key])
+
+ if "room" in user_filter_json:
+ for key in room_level_definitions:
+ if key in user_filter_json["room"]:
+ self._check_definition(user_filter_json["room"][key])
+
+ def _check_definition(self, definition):
+ """Check if the provided definition is valid.
+
+ This inspects not only the types but also the values to make sure they
+ make sense.
+
+ Args:
+ definition(dict): The filter definition
+ Raises:
+ SynapseError: If there was a problem with this definition.
+ """
+ # NB: Filters are the complete json blobs. "Definitions" are an
+ # individual top-level key e.g. public_user_data. Filters are made of
+ # many definitions.
+ if type(definition) != dict:
+ raise SynapseError(
+ 400, "Expected JSON object, not %s" % (definition,)
+ )
+
+ # check rooms are valid room IDs
+ room_id_keys = ["rooms", "not_rooms"]
+ for key in room_id_keys:
+ if key in definition:
+ if type(definition[key]) != list:
+ raise SynapseError(400, "Expected %s to be a list." % key)
+ for room_id in definition[key]:
+ RoomID.from_string(room_id)
+
+ # check senders are valid user IDs
+ user_id_keys = ["senders", "not_senders"]
+ for key in user_id_keys:
+ if key in definition:
+ if type(definition[key]) != list:
+ raise SynapseError(400, "Expected %s to be a list." % key)
+ for user_id in definition[key]:
+ UserID.from_string(user_id)
+
+ # TODO: We don't limit event type values but we probably should...
+ # check types are valid event types
+ event_keys = ["types", "not_types"]
+ for key in event_keys:
+ if key in definition:
+ if type(definition[key]) != list:
+ raise SynapseError(400, "Expected %s to be a list." % key)
+ for event_type in definition[key]:
+ if not isinstance(event_type, basestring):
+ raise SynapseError(400, "Event type should be a string")
+
+ if "format" in definition:
+ event_format = definition["format"]
+ if event_format not in ["federation", "events"]:
+ raise SynapseError(400, "Invalid format: %s" % (event_format,))
+
+ if "select" in definition:
+ event_select_list = definition["select"]
+ for select_key in event_select_list:
+ if select_key not in ["event_id", "origin_server_ts",
+ "thread_id", "content", "content.body"]:
+ raise SynapseError(400, "Bad select: %s" % (select_key,))
+
+ if ("bundle_updates" in definition and
+ type(definition["bundle_updates"]) != bool):
+ raise SynapseError(400, "Bad bundle_updates: expected bool.")
+
+
+class Filter(object):
+ def __init__(self, filter_json):
+ self.filter_json = filter_json
+
+ def filter_public_user_data(self, events):
+ return self._filter_on_key(events, ["public_user_data"])
+
+ def filter_private_user_data(self, events):
+ return self._filter_on_key(events, ["private_user_data"])
+
+ def filter_room_state(self, events):
+ return self._filter_on_key(events, ["room", "state"])
+
+ def filter_room_events(self, events):
+ return self._filter_on_key(events, ["room", "events"])
+
+ def filter_room_ephemeral(self, events):
+ return self._filter_on_key(events, ["room", "ephemeral"])
+
+ def _filter_on_key(self, events, keys):
+ filter_json = self.filter_json
+ if not filter_json:
+ return events
+
+ try:
+ # extract the right definition from the filter
+ definition = filter_json
+ for key in keys:
+ definition = definition[key]
+ return self._filter_with_definition(events, definition)
+ except KeyError:
+ # return all events if definition isn't specified.
+ return events
+
+ def _filter_with_definition(self, events, definition):
+ return [e for e in events if self._passes_definition(definition, e)]
+
+ def _passes_definition(self, definition, event):
+ """Check if the event passes through the given definition.
+
+ Args:
+ definition(dict): The definition to check against.
+ event(Event): The event to check.
+ Returns:
+ True if the event passes through the filter.
+ """
+ # Algorithm notes:
+ # For each key in the definition, check the event meets the criteria:
+ # * For types: Literal match or prefix match (if ends with wildcard)
+ # * For senders/rooms: Literal match only
+ # * "not_" checks take presedence (e.g. if "m.*" is in both 'types'
+ # and 'not_types' then it is treated as only being in 'not_types')
+
+ # room checks
+ if hasattr(event, "room_id"):
+ room_id = event.room_id
+ allow_rooms = definition.get("rooms", None)
+ reject_rooms = definition.get("not_rooms", None)
+ if reject_rooms and room_id in reject_rooms:
+ return False
+ if allow_rooms and room_id not in allow_rooms:
+ return False
+
+ # sender checks
+ if hasattr(event, "sender"):
+ # Should we be including event.state_key for some event types?
+ sender = event.sender
+ allow_senders = definition.get("senders", None)
+ reject_senders = definition.get("not_senders", None)
+ if reject_senders and sender in reject_senders:
+ return False
+ if allow_senders and sender not in allow_senders:
+ return False
+
+ # type checks
+ if "not_types" in definition:
+ for def_type in definition["not_types"]:
+ if self._event_matches_type(event, def_type):
+ return False
+ if "types" in definition:
+ included = False
+ for def_type in definition["types"]:
+ if self._event_matches_type(event, def_type):
+ included = True
+ break
+ if not included:
+ return False
+
+ return True
+
+ def _event_matches_type(self, event, def_type):
+ if def_type.endswith("*"):
+ type_prefix = def_type[:-1]
+ return event.type.startswith(type_prefix)
+ else:
+ return event.type == def_type
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index f7e24d0cc6..a9397de5b2 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -277,6 +277,8 @@ def setup():
bind_port = None
hs.start_listening(bind_port, config.unsecure_port)
+ hs.get_pusherpool().start()
+
if config.daemonize:
print config.pid_file
daemon = Daemonize(
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 9c910fa3fc..cdb6279764 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -61,9 +61,11 @@ class SynapseKeyClientProtocol(HTTPClient):
def __init__(self):
self.remote_key = defer.Deferred()
+ self.host = None
def connectionMade(self):
- logger.debug("Connected to %s", self.transport.getHost())
+ self.host = self.transport.getHost()
+ logger.debug("Connected to %s", self.host)
self.sendCommand(b"GET", b"/_matrix/key/v1/")
self.endHeaders()
self.timer = reactor.callLater(
@@ -92,8 +94,7 @@ class SynapseKeyClientProtocol(HTTPClient):
self.timer.cancel()
def on_timeout(self):
- logger.debug("Timeout waiting for response from %s",
- self.transport.getHost())
+ logger.debug("Timeout waiting for response from %s", self.host)
self.remote_key.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 4252e5ab5c..bf07951027 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -18,7 +18,7 @@ from synapse.util.frozenutils import freeze, unfreeze
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
- self.__dict__ = internal_metadata_dict
+ self.__dict__ = dict(internal_metadata_dict)
def get_dict(self):
return dict(self.__dict__)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index a9b1b99a10..9d45bdb892 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -23,14 +23,15 @@ import copy
class EventBuilder(EventBase):
- def __init__(self, key_values={}):
+ def __init__(self, key_values={}, internal_metadata_dict={}):
signatures = copy.deepcopy(key_values.pop("signatures", {}))
unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
super(EventBuilder, self).__init__(
key_values,
signatures=signatures,
- unsigned=unsigned
+ unsigned=unsigned,
+ internal_metadata_dict=internal_metadata_dict,
)
def build(self):
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 6bbba8d6ba..7e98bdef28 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -20,3 +20,4 @@ class EventContext(object):
self.current_state = current_state
self.auth_events = auth_events
self.state_group = None
+ self.rejected = False
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index e391aca4cc..1aa952150e 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -45,12 +45,14 @@ def prune_event(event):
"membership",
]
+ event_dict = event.get_dict()
+
new_content = {}
def add_fields(*fields):
for field in fields:
if field in event.content:
- new_content[field] = event.content[field]
+ new_content[field] = event_dict["content"][field]
if event_type == EventTypes.Member:
add_fields("membership")
@@ -75,7 +77,7 @@ def prune_event(event):
allowed_fields = {
k: v
- for k, v in event.get_dict().items()
+ for k, v in event_dict.items()
if k in allowed_keys
}
@@ -86,10 +88,53 @@ def prune_event(event):
if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
- return type(event)(allowed_fields)
+ return type(event)(
+ allowed_fields,
+ internal_metadata_dict=event.internal_metadata.get_dict()
+ )
+
+
+def format_event_raw(d):
+ return d
-def serialize_event(e, time_now_ms, client_event=True):
+def format_event_for_client_v1(d):
+ d["user_id"] = d.pop("sender", None)
+
+ move_keys = ("age", "redacted_because", "replaces_state", "prev_content")
+ for key in move_keys:
+ if key in d["unsigned"]:
+ d[key] = d["unsigned"][key]
+
+ drop_keys = (
+ "auth_events", "prev_events", "hashes", "signatures", "depth",
+ "unsigned", "origin", "prev_state"
+ )
+ for key in drop_keys:
+ d.pop(key, None)
+ return d
+
+
+def format_event_for_client_v2(d):
+ drop_keys = (
+ "auth_events", "prev_events", "hashes", "signatures", "depth",
+ "origin", "prev_state",
+ )
+ for key in drop_keys:
+ d.pop(key, None)
+ return d
+
+
+def format_event_for_client_v2_without_event_id(d):
+ d = format_event_for_client_v2(d)
+ d.pop("room_id", None)
+ d.pop("event_id", None)
+ return d
+
+
+def serialize_event(e, time_now_ms, as_client_event=True,
+ event_format=format_event_for_client_v1,
+ token_id=None):
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
@@ -99,43 +144,22 @@ def serialize_event(e, time_now_ms, client_event=True):
# Should this strip out None's?
d = {k: v for k, v in e.get_dict().items()}
- if not client_event:
- # set the age and keep all other keys
- if "age_ts" in d["unsigned"]:
- d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"]
- return d
-
if "age_ts" in d["unsigned"]:
- d["age"] = time_now_ms - d["unsigned"]["age_ts"]
+ d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"]
del d["unsigned"]["age_ts"]
- d["user_id"] = d.pop("sender", None)
-
if "redacted_because" in e.unsigned:
- d["redacted_because"] = serialize_event(
+ d["unsigned"]["redacted_because"] = serialize_event(
e.unsigned["redacted_because"], time_now_ms
)
- del d["unsigned"]["redacted_because"]
+ if token_id is not None:
+ if token_id == getattr(e.internal_metadata, "token_id", None):
+ txn_id = getattr(e.internal_metadata, "txn_id", None)
+ if txn_id is not None:
+ d["unsigned"]["transaction_id"] = txn_id
- if "redacted_by" in e.unsigned:
- d["redacted_by"] = e.unsigned["redacted_by"]
- del d["unsigned"]["redacted_by"]
-
- if "replaces_state" in e.unsigned:
- d["replaces_state"] = e.unsigned["replaces_state"]
- del d["unsigned"]["replaces_state"]
-
- if "prev_content" in e.unsigned:
- d["prev_content"] = e.unsigned["prev_content"]
- del d["unsigned"]["prev_content"]
-
- del d["auth_events"]
- del d["prev_events"]
- del d["hashes"]
- del d["signatures"]
- d.pop("depth", None)
- d.pop("unsigned", None)
- d.pop("origin", None)
-
- return d
+ if as_client_event:
+ return event_format(d)
+ else:
+ return d
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
new file mode 100644
index 0000000000..e1539bd0e0
--- /dev/null
+++ b/synapse/federation/federation_client.py
@@ -0,0 +1,409 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from .units import Edu
+
+from synapse.util.logutils import log_function
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from syutil.jsonutil import encode_canonical_json
+
+from synapse.crypto.event_signing import check_event_content_hash
+
+from synapse.api.errors import SynapseError
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class FederationClient(object):
+ @log_function
+ def send_pdu(self, pdu, destinations):
+ """Informs the replication layer about a new PDU generated within the
+ home server that should be transmitted to others.
+
+ TODO: Figure out when we should actually resolve the deferred.
+
+ Args:
+ pdu (Pdu): The new Pdu.
+
+ Returns:
+ Deferred: Completes when we have successfully processed the PDU
+ and replicated it to any interested remote home servers.
+ """
+ order = self._order
+ self._order += 1
+
+ logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
+
+ # TODO, add errback, etc.
+ self._transaction_queue.enqueue_pdu(pdu, destinations, order)
+
+ logger.debug(
+ "[%s] transaction_layer.enqueue_pdu... done",
+ pdu.event_id
+ )
+
+ @log_function
+ def send_edu(self, destination, edu_type, content):
+ edu = Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type=edu_type,
+ content=content,
+ )
+
+ # TODO, add errback, etc.
+ self._transaction_queue.enqueue_edu(edu)
+ return defer.succeed(None)
+
+ @log_function
+ def send_failure(self, failure, destination):
+ self._transaction_queue.enqueue_failure(failure, destination)
+ return defer.succeed(None)
+
+ @log_function
+ def make_query(self, destination, query_type, args,
+ retry_on_dns_fail=True):
+ """Sends a federation Query to a remote homeserver of the given type
+ and arguments.
+
+ Args:
+ destination (str): Domain name of the remote homeserver
+ query_type (str): Category of the query type; should match the
+ handler name used in register_query_handler().
+ args (dict): Mapping of strings to strings containing the details
+ of the query request.
+
+ Returns:
+ a Deferred which will eventually yield a JSON object from the
+ response
+ """
+ return self.transport_layer.make_query(
+ destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def backfill(self, dest, context, limit, extremities):
+ """Requests some more historic PDUs for the given context from the
+ given destination server.
+
+ Args:
+ dest (str): The remote home server to ask.
+ context (str): The context to backfill.
+ limit (int): The maximum number of PDUs to return.
+ extremities (list): List of PDU id and origins of the first pdus
+ we have seen from the context
+
+ Returns:
+ Deferred: Results in the received PDUs.
+ """
+ logger.debug("backfill extrem=%s", extremities)
+
+ # If there are no extremeties then we've (probably) reached the start.
+ if not extremities:
+ return
+
+ transaction_data = yield self.transport_layer.backfill(
+ dest, context, extremities, limit)
+
+ logger.debug("backfill transaction_data=%s", repr(transaction_data))
+
+ pdus = [
+ self.event_from_pdu_json(p, outlier=False)
+ for p in transaction_data["pdus"]
+ ]
+
+ for i, pdu in enumerate(pdus):
+ pdus[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ defer.returnValue(pdus)
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_pdu(self, destinations, event_id, outlier=False):
+ """Requests the PDU with given origin and ID from the remote home
+ servers.
+
+ Will attempt to get the PDU from each destination in the list until
+ one succeeds.
+
+ This will persist the PDU locally upon receipt.
+
+ Args:
+ destinations (list): Which home servers to query
+ pdu_origin (str): The home server that originally sent the pdu.
+ event_id (str)
+ outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
+ it's from an arbitary point in the context as opposed to part
+ of the current block of PDUs. Defaults to `False`
+
+ Returns:
+ Deferred: Results in the requested PDU.
+ """
+
+ # TODO: Rate limit the number of times we try and get the same event.
+
+ pdu = None
+ for destination in destinations:
+ try:
+ transaction_data = yield self.transport_layer.get_event(
+ destination, event_id
+ )
+
+ logger.debug("transaction_data %r", transaction_data)
+
+ pdu_list = [
+ self.event_from_pdu_json(p, outlier=outlier)
+ for p in transaction_data["pdus"]
+ ]
+
+ if pdu_list:
+ pdu = pdu_list[0]
+
+ # Check signatures are correct.
+ pdu = yield self._check_sigs_and_hash(pdu)
+
+ break
+
+ except Exception as e:
+ logger.info(
+ "Failed to get PDU %s from %s because %s",
+ event_id, destination, e,
+ )
+ continue
+
+ defer.returnValue(pdu)
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_state_for_room(self, destination, room_id, event_id):
+ """Requests all of the `current` state PDUs for a given room from
+ a remote home server.
+
+ Args:
+ destination (str): The remote homeserver to query for the state.
+ room_id (str): The id of the room we're interested in.
+ event_id (str): The id of the event we want the state at.
+
+ Returns:
+ Deferred: Results in a list of PDUs.
+ """
+
+ result = yield self.transport_layer.get_room_state(
+ destination, room_id, event_id=event_id,
+ )
+
+ pdus = [
+ self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+ ]
+
+ auth_chain = [
+ self.event_from_pdu_json(p, outlier=True)
+ for p in result.get("auth_chain", [])
+ ]
+
+ for i, pdu in enumerate(pdus):
+ pdus[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ for i, pdu in enumerate(auth_chain):
+ auth_chain[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ defer.returnValue((pdus, auth_chain))
+
+ @defer.inlineCallbacks
+ @log_function
+ def get_event_auth(self, destination, room_id, event_id):
+ res = yield self.transport_layer.get_event_auth(
+ destination, room_id, event_id,
+ )
+
+ auth_chain = [
+ self.event_from_pdu_json(p, outlier=True)
+ for p in res["auth_chain"]
+ ]
+
+ for i, pdu in enumerate(auth_chain):
+ auth_chain[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ auth_chain.sort(key=lambda e: e.depth)
+
+ defer.returnValue(auth_chain)
+
+ @defer.inlineCallbacks
+ def make_join(self, destination, room_id, user_id):
+ ret = yield self.transport_layer.make_join(
+ destination, room_id, user_id
+ )
+
+ pdu_dict = ret["event"]
+
+ logger.debug("Got response to make_join: %s", pdu_dict)
+
+ defer.returnValue(self.event_from_pdu_json(pdu_dict))
+
+ @defer.inlineCallbacks
+ def send_join(self, destination, pdu):
+ time_now = self._clock.time_msec()
+ _, content = yield self.transport_layer.send_join(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ logger.debug("Got content: %s", content)
+
+ state = [
+ self.event_from_pdu_json(p, outlier=True)
+ for p in content.get("state", [])
+ ]
+
+ auth_chain = [
+ self.event_from_pdu_json(p, outlier=True)
+ for p in content.get("auth_chain", [])
+ ]
+
+ for i, pdu in enumerate(state):
+ state[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ for i, pdu in enumerate(auth_chain):
+ auth_chain[i] = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ auth_chain.sort(key=lambda e: e.depth)
+
+ defer.returnValue({
+ "state": state,
+ "auth_chain": auth_chain,
+ })
+
+ @defer.inlineCallbacks
+ def send_invite(self, destination, room_id, event_id, pdu):
+ time_now = self._clock.time_msec()
+ code, content = yield self.transport_layer.send_invite(
+ destination=destination,
+ room_id=room_id,
+ event_id=event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ pdu_dict = content["event"]
+
+ logger.debug("Got response to send_invite: %s", pdu_dict)
+
+ pdu = self.event_from_pdu_json(pdu_dict)
+
+ # Check signatures are correct.
+ pdu = yield self._check_sigs_and_hash(pdu)
+
+ # FIXME: We should handle signature failures more gracefully.
+
+ defer.returnValue(pdu)
+
+ @defer.inlineCallbacks
+ def query_auth(self, destination, room_id, event_id, local_auth):
+ """
+ Params:
+ destination (str)
+ event_it (str)
+ local_auth (list)
+ """
+ time_now = self._clock.time_msec()
+
+ send_content = {
+ "auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
+ }
+
+ code, content = yield self.transport_layer.send_query_auth(
+ destination=destination,
+ room_id=room_id,
+ event_id=event_id,
+ content=send_content,
+ )
+
+ auth_chain = [
+ (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+ for e in content["auth_chain"]
+ ]
+
+ ret = {
+ "auth_chain": auth_chain,
+ "rejects": content.get("rejects", []),
+ "missing": content.get("missing", []),
+ }
+
+ defer.returnValue(ret)
+
+ def event_from_pdu_json(self, pdu_json, outlier=False):
+ event = FrozenEvent(
+ pdu_json
+ )
+
+ event.internal_metadata.outlier = outlier
+
+ return event
+
+ @defer.inlineCallbacks
+ def _check_sigs_and_hash(self, pdu):
+ """Throws a SynapseError if the PDU does not have the correct
+ signatures.
+
+ Returns:
+ FrozenEvent: Either the given event or it redacted if it failed the
+ content hash check.
+ """
+ # Check signatures are correct.
+ redacted_event = prune_event(pdu)
+ redacted_pdu_json = redacted_event.get_pdu_json()
+
+ try:
+ yield self.keyring.verify_json_for_server(
+ pdu.origin, redacted_pdu_json
+ )
+ except SynapseError:
+ logger.warn(
+ "Signature check failed for %s redacted to %s",
+ encode_canonical_json(pdu.get_pdu_json()),
+ encode_canonical_json(redacted_pdu_json),
+ )
+ raise
+
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s, %s",
+ pdu.event_id, encode_canonical_json(pdu.get_dict())
+ )
+ defer.returnValue(redacted_event)
+
+ defer.returnValue(pdu)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
new file mode 100644
index 0000000000..5fbd8b19de
--- /dev/null
+++ b/synapse/federation/federation_server.py
@@ -0,0 +1,462 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from .units import Transaction, Edu
+
+from synapse.util.logutils import log_function
+from synapse.util.logcontext import PreserveLoggingContext
+from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from syutil.jsonutil import encode_canonical_json
+
+from synapse.crypto.event_signing import check_event_content_hash
+
+from synapse.api.errors import FederationError, SynapseError
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class FederationServer(object):
+ def set_handler(self, handler):
+ """Sets the handler that the replication layer will use to communicate
+ receipt of new PDUs from other home servers. The required methods are
+ documented on :py:class:`.ReplicationHandler`.
+ """
+ self.handler = handler
+
+ def register_edu_handler(self, edu_type, handler):
+ if edu_type in self.edu_handlers:
+ raise KeyError("Already have an EDU handler for %s" % (edu_type,))
+
+ self.edu_handlers[edu_type] = handler
+
+ def register_query_handler(self, query_type, handler):
+ """Sets the handler callable that will be used to handle an incoming
+ federation Query of the given type.
+
+ Args:
+ query_type (str): Category name of the query, which should match
+ the string used by make_query.
+ handler (callable): Invoked to handle incoming queries of this type
+
+ handler is invoked as:
+ result = handler(args)
+
+ where 'args' is a dict mapping strings to strings of the query
+ arguments. It should return a Deferred that will eventually yield an
+ object to encode as JSON.
+ """
+ if query_type in self.query_handlers:
+ raise KeyError(
+ "Already have a Query handler for %s" % (query_type,)
+ )
+
+ self.query_handlers[query_type] = handler
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_backfill_request(self, origin, room_id, versions, limit):
+ pdus = yield self.handler.on_backfill_request(
+ origin, room_id, versions, limit
+ )
+
+ defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_incoming_transaction(self, transaction_data):
+ transaction = Transaction(**transaction_data)
+
+ for p in transaction.pdus:
+ if "unsigned" in p:
+ unsigned = p["unsigned"]
+ if "age" in unsigned:
+ p["age"] = unsigned["age"]
+ if "age" in p:
+ p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
+ del p["age"]
+
+ pdu_list = [
+ self.event_from_pdu_json(p) for p in transaction.pdus
+ ]
+
+ logger.debug("[%s] Got transaction", transaction.transaction_id)
+
+ response = yield self.transaction_actions.have_responded(transaction)
+
+ if response:
+ logger.debug(
+ "[%s] We've already responed to this request",
+ transaction.transaction_id
+ )
+ defer.returnValue(response)
+ return
+
+ logger.debug("[%s] Transaction is new", transaction.transaction_id)
+
+ with PreserveLoggingContext():
+ dl = []
+ for pdu in pdu_list:
+ dl.append(self._handle_new_pdu(transaction.origin, pdu))
+
+ if hasattr(transaction, "edus"):
+ for edu in [Edu(**x) for x in transaction.edus]:
+ self.received_edu(
+ transaction.origin,
+ edu.edu_type,
+ edu.content
+ )
+
+ results = yield defer.DeferredList(dl)
+
+ ret = []
+ for r in results:
+ if r[0]:
+ ret.append({})
+ else:
+ logger.exception(r[1])
+ ret.append({"error": str(r[1])})
+
+ logger.debug("Returning: %s", str(ret))
+
+ yield self.transaction_actions.set_response(
+ transaction,
+ 200, response
+ )
+ defer.returnValue((200, response))
+
+ def received_edu(self, origin, edu_type, content):
+ if edu_type in self.edu_handlers:
+ self.edu_handlers[edu_type](origin, content)
+ else:
+ logger.warn("Received EDU of type %s with no handler", edu_type)
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_context_state_request(self, origin, room_id, event_id):
+ if event_id:
+ pdus = yield self.handler.get_state_for_pdu(
+ origin, room_id, event_id,
+ )
+ auth_chain = yield self.store.get_auth_chain(
+ [pdu.event_id for pdu in pdus]
+ )
+ else:
+ raise NotImplementedError("Specify an event")
+
+ defer.returnValue((200, {
+ "pdus": [pdu.get_pdu_json() for pdu in pdus],
+ "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
+ }))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_pdu_request(self, origin, event_id):
+ pdu = yield self._get_persisted_pdu(origin, event_id)
+
+ if pdu:
+ defer.returnValue(
+ (200, self._transaction_from_pdus([pdu]).get_dict())
+ )
+ else:
+ defer.returnValue((404, ""))
+
+ @defer.inlineCallbacks
+ @log_function
+ def on_pull_request(self, origin, versions):
+ raise NotImplementedError("Pull transactions not implemented")
+
+ @defer.inlineCallbacks
+ def on_query_request(self, query_type, args):
+ if query_type in self.query_handlers:
+ response = yield self.query_handlers[query_type](args)
+ defer.returnValue((200, response))
+ else:
+ defer.returnValue(
+ (404, "No handler for Query type '%s'" % (query_type,))
+ )
+
+ @defer.inlineCallbacks
+ def on_make_join_request(self, room_id, user_id):
+ pdu = yield self.handler.on_make_join_request(room_id, user_id)
+ time_now = self._clock.time_msec()
+ defer.returnValue({"event": pdu.get_pdu_json(time_now)})
+
+ @defer.inlineCallbacks
+ def on_invite_request(self, origin, content):
+ pdu = self.event_from_pdu_json(content)
+ ret_pdu = yield self.handler.on_invite_request(origin, pdu)
+ time_now = self._clock.time_msec()
+ defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
+
+ @defer.inlineCallbacks
+ def on_send_join_request(self, origin, content):
+ logger.debug("on_send_join_request: content: %s", content)
+ pdu = self.event_from_pdu_json(content)
+ logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
+ res_pdus = yield self.handler.on_send_join_request(origin, pdu)
+ time_now = self._clock.time_msec()
+ defer.returnValue((200, {
+ "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+ "auth_chain": [
+ p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
+ ],
+ }))
+
+ @defer.inlineCallbacks
+ def on_event_auth(self, origin, room_id, event_id):
+ time_now = self._clock.time_msec()
+ auth_pdus = yield self.handler.on_event_auth(event_id)
+ defer.returnValue((200, {
+ "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
+ }))
+
+ @defer.inlineCallbacks
+ def on_query_auth_request(self, origin, content, event_id):
+ """
+ Content is a dict with keys::
+ auth_chain (list): A list of events that give the auth chain.
+ missing (list): A list of event_ids indicating what the other
+ side (`origin`) think we're missing.
+ rejects (dict): A mapping from event_id to a 2-tuple of reason
+ string and a proof (or None) of why the event was rejected.
+ The keys of this dict give the list of events the `origin` has
+ rejected.
+
+ Args:
+ origin (str)
+ content (dict)
+ event_id (str)
+
+ Returns:
+ Deferred: Results in `dict` with the same format as `content`
+ """
+ auth_chain = [
+ (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+ for e in content["auth_chain"]
+ ]
+
+ missing = [
+ (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+ for e in content.get("missing", [])
+ ]
+
+ ret = yield self.handler.on_query_auth(
+ origin, event_id, auth_chain, content.get("rejects", []), missing
+ )
+
+ time_now = self._clock.time_msec()
+ send_content = {
+ "auth_chain": [
+ e.get_pdu_json(time_now)
+ for e in ret["auth_chain"]
+ ],
+ "rejects": ret.get("rejects", []),
+ "missing": ret.get("missing", []),
+ }
+
+ defer.returnValue(
+ (200, send_content)
+ )
+
+ @log_function
+ def _get_persisted_pdu(self, origin, event_id, do_auth=True):
+ """ Get a PDU from the database with given origin and id.
+
+ Returns:
+ Deferred: Results in a `Pdu`.
+ """
+ return self.handler.get_persisted_pdu(
+ origin, event_id, do_auth=do_auth
+ )
+
+ def _transaction_from_pdus(self, pdu_list):
+ """Returns a new Transaction containing the given PDUs suitable for
+ transmission.
+ """
+ time_now = self._clock.time_msec()
+ pdus = [p.get_pdu_json(time_now) for p in pdu_list]
+ return Transaction(
+ origin=self.server_name,
+ pdus=pdus,
+ origin_server_ts=int(time_now),
+ destination=None,
+ )
+
+ @defer.inlineCallbacks
+ @log_function
+ def _handle_new_pdu(self, origin, pdu, max_recursion=10):
+ # We reprocess pdus when we have seen them only as outliers
+ existing = yield self._get_persisted_pdu(
+ origin, pdu.event_id, do_auth=False
+ )
+
+ # FIXME: Currently we fetch an event again when we already have it
+ # if it has been marked as an outlier.
+
+ already_seen = (
+ existing and (
+ not existing.internal_metadata.is_outlier()
+ or pdu.internal_metadata.is_outlier()
+ )
+ )
+ if already_seen:
+ logger.debug("Already seen pdu %s", pdu.event_id)
+ defer.returnValue({})
+ return
+
+ # Check signature.
+ try:
+ pdu = yield self._check_sigs_and_hash(pdu)
+ except SynapseError as e:
+ raise FederationError(
+ "ERROR",
+ e.code,
+ e.msg,
+ affected=pdu.event_id,
+ )
+
+ state = None
+
+ auth_chain = []
+
+ have_seen = yield self.store.have_events(
+ [ev for ev, _ in pdu.prev_events]
+ )
+
+ fetch_state = False
+
+ # Get missing pdus if necessary.
+ if not pdu.internal_metadata.is_outlier():
+ # We only backfill backwards to the min depth.
+ min_depth = yield self.handler.get_min_depth_for_context(
+ pdu.room_id
+ )
+
+ logger.debug(
+ "_handle_new_pdu min_depth for %s: %d",
+ pdu.room_id, min_depth
+ )
+
+ if min_depth and pdu.depth > min_depth and max_recursion > 0:
+ for event_id, hashes in pdu.prev_events:
+ if event_id not in have_seen:
+ logger.debug(
+ "_handle_new_pdu requesting pdu %s",
+ event_id
+ )
+
+ try:
+ new_pdu = yield self.federation_client.get_pdu(
+ [origin, pdu.origin],
+ event_id=event_id,
+ )
+
+ if new_pdu:
+ yield self._handle_new_pdu(
+ origin,
+ new_pdu,
+ max_recursion=max_recursion-1
+ )
+
+ logger.debug("Processed pdu %s", event_id)
+ else:
+ logger.warn("Failed to get PDU %s", event_id)
+ fetch_state = True
+ except:
+ # TODO(erikj): Do some more intelligent retries.
+ logger.exception("Failed to get PDU")
+ fetch_state = True
+ else:
+ prevs = {e_id for e_id, _ in pdu.prev_events}
+ seen = set(have_seen.keys())
+ if prevs - seen:
+ fetch_state = True
+ else:
+ fetch_state = True
+
+ if fetch_state:
+ # We need to get the state at this event, since we haven't
+ # processed all the prev events.
+ logger.debug(
+ "_handle_new_pdu getting state for %s",
+ pdu.room_id
+ )
+ state, auth_chain = yield self.get_state_for_room(
+ origin, pdu.room_id, pdu.event_id,
+ )
+
+ ret = yield self.handler.on_receive_pdu(
+ origin,
+ pdu,
+ backfilled=False,
+ state=state,
+ auth_chain=auth_chain,
+ )
+
+ defer.returnValue(ret)
+
+ def __str__(self):
+ return "<ReplicationLayer(%s)>" % self.server_name
+
+ def event_from_pdu_json(self, pdu_json, outlier=False):
+ event = FrozenEvent(
+ pdu_json
+ )
+
+ event.internal_metadata.outlier = outlier
+
+ return event
+
+ @defer.inlineCallbacks
+ def _check_sigs_and_hash(self, pdu):
+ """Throws a SynapseError if the PDU does not have the correct
+ signatures.
+
+ Returns:
+ FrozenEvent: Either the given event or it redacted if it failed the
+ content hash check.
+ """
+ # Check signatures are correct.
+ redacted_event = prune_event(pdu)
+ redacted_pdu_json = redacted_event.get_pdu_json()
+
+ try:
+ yield self.keyring.verify_json_for_server(
+ pdu.origin, redacted_pdu_json
+ )
+ except SynapseError:
+ logger.warn(
+ "Signature check failed for %s redacted to %s",
+ encode_canonical_json(pdu.get_pdu_json()),
+ encode_canonical_json(redacted_pdu_json),
+ )
+ raise
+
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s, %s",
+ pdu.event_id, encode_canonical_json(pdu.get_dict())
+ )
+ defer.returnValue(redacted_event)
+
+ defer.returnValue(pdu)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 6620532a60..e442c6c5d5 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -17,23 +17,20 @@
a given transport.
"""
-from twisted.internet import defer
+from .federation_client import FederationClient
+from .federation_server import FederationServer
-from .units import Transaction, Edu
+from .transaction_queue import TransactionQueue
from .persistence import TransactionActions
-from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
-from synapse.events import FrozenEvent
-
import logging
logger = logging.getLogger(__name__)
-class ReplicationLayer(object):
+class ReplicationLayer(FederationClient, FederationServer):
"""This layer is responsible for replicating with remote home servers over
the given transport. I.e., does the sending and receiving of PDUs to
remote home servers.
@@ -54,898 +51,26 @@ class ReplicationLayer(object):
def __init__(self, hs, transport_layer):
self.server_name = hs.hostname
+ self.keyring = hs.get_keyring()
+
self.transport_layer = transport_layer
self.transport_layer.register_received_handler(self)
self.transport_layer.register_request_handler(self)
- self.store = hs.get_datastore()
- # self.pdu_actions = PduActions(self.store)
- self.transaction_actions = TransactionActions(self.store)
+ self.federation_client = self
- self._transaction_queue = _TransactionQueue(
- hs, self.transaction_actions, transport_layer
- )
+ self.store = hs.get_datastore()
self.handler = None
self.edu_handlers = {}
self.query_handlers = {}
- self._order = 0
-
self._clock = hs.get_clock()
- self.event_builder_factory = hs.get_event_builder_factory()
-
- def set_handler(self, handler):
- """Sets the handler that the replication layer will use to communicate
- receipt of new PDUs from other home servers. The required methods are
- documented on :py:class:`.ReplicationHandler`.
- """
- self.handler = handler
-
- def register_edu_handler(self, edu_type, handler):
- if edu_type in self.edu_handlers:
- raise KeyError("Already have an EDU handler for %s" % (edu_type,))
-
- self.edu_handlers[edu_type] = handler
-
- def register_query_handler(self, query_type, handler):
- """Sets the handler callable that will be used to handle an incoming
- federation Query of the given type.
-
- Args:
- query_type (str): Category name of the query, which should match
- the string used by make_query.
- handler (callable): Invoked to handle incoming queries of this type
-
- handler is invoked as:
- result = handler(args)
-
- where 'args' is a dict mapping strings to strings of the query
- arguments. It should return a Deferred that will eventually yield an
- object to encode as JSON.
- """
- if query_type in self.query_handlers:
- raise KeyError(
- "Already have a Query handler for %s" % (query_type,)
- )
-
- self.query_handlers[query_type] = handler
-
- @log_function
- def send_pdu(self, pdu, destinations):
- """Informs the replication layer about a new PDU generated within the
- home server that should be transmitted to others.
-
- TODO: Figure out when we should actually resolve the deferred.
-
- Args:
- pdu (Pdu): The new Pdu.
-
- Returns:
- Deferred: Completes when we have successfully processed the PDU
- and replicated it to any interested remote home servers.
- """
- order = self._order
- self._order += 1
-
- logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
-
- # TODO, add errback, etc.
- self._transaction_queue.enqueue_pdu(pdu, destinations, order)
-
- logger.debug(
- "[%s] transaction_layer.enqueue_pdu... done",
- pdu.event_id
- )
-
- @log_function
- def send_edu(self, destination, edu_type, content):
- edu = Edu(
- origin=self.server_name,
- destination=destination,
- edu_type=edu_type,
- content=content,
- )
-
- # TODO, add errback, etc.
- self._transaction_queue.enqueue_edu(edu)
- return defer.succeed(None)
-
- @log_function
- def send_failure(self, failure, destination):
- self._transaction_queue.enqueue_failure(failure, destination)
- return defer.succeed(None)
-
- @log_function
- def make_query(self, destination, query_type, args,
- retry_on_dns_fail=True):
- """Sends a federation Query to a remote homeserver of the given type
- and arguments.
-
- Args:
- destination (str): Domain name of the remote homeserver
- query_type (str): Category of the query type; should match the
- handler name used in register_query_handler().
- args (dict): Mapping of strings to strings containing the details
- of the query request.
-
- Returns:
- a Deferred which will eventually yield a JSON object from the
- response
- """
- return self.transport_layer.make_query(
- destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
- )
-
- @defer.inlineCallbacks
- @log_function
- def backfill(self, dest, context, limit, extremities):
- """Requests some more historic PDUs for the given context from the
- given destination server.
-
- Args:
- dest (str): The remote home server to ask.
- context (str): The context to backfill.
- limit (int): The maximum number of PDUs to return.
- extremities (list): List of PDU id and origins of the first pdus
- we have seen from the context
-
- Returns:
- Deferred: Results in the received PDUs.
- """
- logger.debug("backfill extrem=%s", extremities)
-
- # If there are no extremeties then we've (probably) reached the start.
- if not extremities:
- return
-
- transaction_data = yield self.transport_layer.backfill(
- dest, context, extremities, limit)
-
- logger.debug("backfill transaction_data=%s", repr(transaction_data))
-
- transaction = Transaction(**transaction_data)
-
- pdus = [
- self.event_from_pdu_json(p, outlier=False)
- for p in transaction.pdus
- ]
- for pdu in pdus:
- yield self._handle_new_pdu(dest, pdu, backfilled=True)
-
- defer.returnValue(pdus)
-
- @defer.inlineCallbacks
- @log_function
- def get_pdu(self, destination, event_id, outlier=False):
- """Requests the PDU with given origin and ID from the remote home
- server.
-
- This will persist the PDU locally upon receipt.
-
- Args:
- destination (str): Which home server to query
- pdu_origin (str): The home server that originally sent the pdu.
- event_id (str)
- outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
- it's from an arbitary point in the context as opposed to part
- of the current block of PDUs. Defaults to `False`
-
- Returns:
- Deferred: Results in the requested PDU.
- """
-
- transaction_data = yield self.transport_layer.get_event(
- destination, event_id
- )
-
- transaction = Transaction(**transaction_data)
-
- pdu_list = [
- self.event_from_pdu_json(p, outlier=outlier)
- for p in transaction.pdus
- ]
-
- pdu = None
- if pdu_list:
- pdu = pdu_list[0]
- yield self._handle_new_pdu(destination, pdu)
-
- defer.returnValue(pdu)
-
- @defer.inlineCallbacks
- @log_function
- def get_state_for_room(self, destination, room_id, event_id):
- """Requests all of the `current` state PDUs for a given room from
- a remote home server.
-
- Args:
- destination (str): The remote homeserver to query for the state.
- room_id (str): The id of the room we're interested in.
- event_id (str): The id of the event we want the state at.
-
- Returns:
- Deferred: Results in a list of PDUs.
- """
-
- result = yield self.transport_layer.get_room_state(
- destination, room_id, event_id=event_id,
- )
-
- pdus = [
- self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
- ]
-
- auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
- for p in result.get("auth_chain", [])
- ]
-
- defer.returnValue((pdus, auth_chain))
-
- @defer.inlineCallbacks
- @log_function
- def get_event_auth(self, destination, room_id, event_id):
- res = yield self.transport_layer.get_event_auth(
- destination, room_id, event_id,
- )
-
- auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
- for p in res["auth_chain"]
- ]
-
- auth_chain.sort(key=lambda e: e.depth)
-
- defer.returnValue(auth_chain)
-
- @defer.inlineCallbacks
- @log_function
- def on_backfill_request(self, origin, room_id, versions, limit):
- pdus = yield self.handler.on_backfill_request(
- origin, room_id, versions, limit
- )
-
- defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict()))
-
- @defer.inlineCallbacks
- @log_function
- def on_incoming_transaction(self, transaction_data):
- transaction = Transaction(**transaction_data)
-
- for p in transaction.pdus:
- if "unsigned" in p:
- unsigned = p["unsigned"]
- if "age" in unsigned:
- p["age"] = unsigned["age"]
- if "age" in p:
- p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
- del p["age"]
-
- pdu_list = [
- self.event_from_pdu_json(p) for p in transaction.pdus
- ]
-
- logger.debug("[%s] Got transaction", transaction.transaction_id)
-
- response = yield self.transaction_actions.have_responded(transaction)
-
- if response:
- logger.debug("[%s] We've already responed to this request",
- transaction.transaction_id)
- defer.returnValue(response)
- return
-
- logger.debug("[%s] Transaction is new", transaction.transaction_id)
-
- with PreserveLoggingContext():
- dl = []
- for pdu in pdu_list:
- dl.append(self._handle_new_pdu(transaction.origin, pdu))
-
- if hasattr(transaction, "edus"):
- for edu in [Edu(**x) for x in transaction.edus]:
- self.received_edu(
- transaction.origin,
- edu.edu_type,
- edu.content
- )
-
- results = yield defer.DeferredList(dl)
-
- ret = []
- for r in results:
- if r[0]:
- ret.append({})
- else:
- logger.exception(r[1])
- ret.append({"error": str(r[1])})
-
- logger.debug("Returning: %s", str(ret))
-
- yield self.transaction_actions.set_response(
- transaction,
- 200, response
- )
- defer.returnValue((200, response))
-
- def received_edu(self, origin, edu_type, content):
- if edu_type in self.edu_handlers:
- self.edu_handlers[edu_type](origin, content)
- else:
- logger.warn("Received EDU of type %s with no handler", edu_type)
-
- @defer.inlineCallbacks
- @log_function
- def on_context_state_request(self, origin, room_id, event_id):
- if event_id:
- pdus = yield self.handler.get_state_for_pdu(
- origin, room_id, event_id,
- )
- auth_chain = yield self.store.get_auth_chain(
- [pdu.event_id for pdu in pdus]
- )
- else:
- raise NotImplementedError("Specify an event")
-
- defer.returnValue((200, {
- "pdus": [pdu.get_pdu_json() for pdu in pdus],
- "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
- }))
-
- @defer.inlineCallbacks
- @log_function
- def on_pdu_request(self, origin, event_id):
- pdu = yield self._get_persisted_pdu(origin, event_id)
-
- if pdu:
- defer.returnValue(
- (200, self._transaction_from_pdus([pdu]).get_dict())
- )
- else:
- defer.returnValue((404, ""))
-
- @defer.inlineCallbacks
- @log_function
- def on_pull_request(self, origin, versions):
- raise NotImplementedError("Pull transactions not implemented")
-
- @defer.inlineCallbacks
- def on_query_request(self, query_type, args):
- if query_type in self.query_handlers:
- response = yield self.query_handlers[query_type](args)
- defer.returnValue((200, response))
- else:
- defer.returnValue(
- (404, "No handler for Query type '%s'" % (query_type,))
- )
-
- @defer.inlineCallbacks
- def on_make_join_request(self, room_id, user_id):
- pdu = yield self.handler.on_make_join_request(room_id, user_id)
- time_now = self._clock.time_msec()
- defer.returnValue({"event": pdu.get_pdu_json(time_now)})
-
- @defer.inlineCallbacks
- def on_invite_request(self, origin, content):
- pdu = self.event_from_pdu_json(content)
- ret_pdu = yield self.handler.on_invite_request(origin, pdu)
- time_now = self._clock.time_msec()
- defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
-
- @defer.inlineCallbacks
- def on_send_join_request(self, origin, content):
- logger.debug("on_send_join_request: content: %s", content)
- pdu = self.event_from_pdu_json(content)
- logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
- res_pdus = yield self.handler.on_send_join_request(origin, pdu)
- time_now = self._clock.time_msec()
- defer.returnValue((200, {
- "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [
- p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
- ],
- }))
-
- @defer.inlineCallbacks
- def on_event_auth(self, origin, room_id, event_id):
- time_now = self._clock.time_msec()
- auth_pdus = yield self.handler.on_event_auth(event_id)
- defer.returnValue((200, {
- "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
- }))
-
- @defer.inlineCallbacks
- def make_join(self, destination, room_id, user_id):
- ret = yield self.transport_layer.make_join(
- destination, room_id, user_id
- )
-
- pdu_dict = ret["event"]
-
- logger.debug("Got response to make_join: %s", pdu_dict)
-
- defer.returnValue(self.event_from_pdu_json(pdu_dict))
-
- @defer.inlineCallbacks
- def send_join(self, destination, pdu):
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_join(
- destination=destination,
- room_id=pdu.room_id,
- event_id=pdu.event_id,
- content=pdu.get_pdu_json(time_now),
- )
-
- logger.debug("Got content: %s", content)
-
- state = [
- self.event_from_pdu_json(p, outlier=True)
- for p in content.get("state", [])
- ]
-
- auth_chain = [
- self.event_from_pdu_json(p, outlier=True)
- for p in content.get("auth_chain", [])
- ]
-
- auth_chain.sort(key=lambda e: e.depth)
-
- defer.returnValue({
- "state": state,
- "auth_chain": auth_chain,
- })
-
- @defer.inlineCallbacks
- def send_invite(self, destination, room_id, event_id, pdu):
- time_now = self._clock.time_msec()
- code, content = yield self.transport_layer.send_invite(
- destination=destination,
- room_id=room_id,
- event_id=event_id,
- content=pdu.get_pdu_json(time_now),
- )
-
- pdu_dict = content["event"]
-
- logger.debug("Got response to send_invite: %s", pdu_dict)
-
- defer.returnValue(self.event_from_pdu_json(pdu_dict))
-
- @log_function
- def _get_persisted_pdu(self, origin, event_id, do_auth=True):
- """ Get a PDU from the database with given origin and id.
-
- Returns:
- Deferred: Results in a `Pdu`.
- """
- return self.handler.get_persisted_pdu(
- origin, event_id, do_auth=do_auth
- )
-
- def _transaction_from_pdus(self, pdu_list):
- """Returns a new Transaction containing the given PDUs suitable for
- transmission.
- """
- time_now = self._clock.time_msec()
- pdus = [p.get_pdu_json(time_now) for p in pdu_list]
- return Transaction(
- origin=self.server_name,
- pdus=pdus,
- origin_server_ts=int(time_now),
- destination=None,
- )
-
- @defer.inlineCallbacks
- @log_function
- def _handle_new_pdu(self, origin, pdu, backfilled=False):
- # We reprocess pdus when we have seen them only as outliers
- existing = yield self._get_persisted_pdu(
- origin, pdu.event_id, do_auth=False
- )
-
- already_seen = (
- existing and (
- not existing.internal_metadata.is_outlier()
- or pdu.internal_metadata.is_outlier()
- )
- )
- if already_seen:
- logger.debug("Already seen pdu %s", pdu.event_id)
- defer.returnValue({})
- return
-
- state = None
-
- auth_chain = []
-
- # We need to make sure we have all the auth events.
- # for e_id, _ in pdu.auth_events:
- # exists = yield self._get_persisted_pdu(
- # origin,
- # e_id,
- # do_auth=False
- # )
- #
- # if not exists:
- # try:
- # logger.debug(
- # "_handle_new_pdu fetch missing auth event %s from %s",
- # e_id,
- # origin,
- # )
- #
- # yield self.get_pdu(
- # origin,
- # event_id=e_id,
- # outlier=True,
- # )
- #
- # logger.debug("Processed pdu %s", e_id)
- # except:
- # logger.warn(
- # "Failed to get auth event %s from %s",
- # e_id,
- # origin
- # )
-
- # Get missing pdus if necessary.
- if not pdu.internal_metadata.is_outlier():
- # We only backfill backwards to the min depth.
- min_depth = yield self.handler.get_min_depth_for_context(
- pdu.room_id
- )
-
- logger.debug(
- "_handle_new_pdu min_depth for %s: %d",
- pdu.room_id, min_depth
- )
-
- if min_depth and pdu.depth > min_depth:
- for event_id, hashes in pdu.prev_events:
- exists = yield self._get_persisted_pdu(
- origin,
- event_id,
- do_auth=False
- )
-
- if not exists:
- logger.debug(
- "_handle_new_pdu requesting pdu %s",
- event_id
- )
-
- try:
- yield self.get_pdu(
- origin,
- event_id=event_id,
- )
- logger.debug("Processed pdu %s", event_id)
- except:
- # TODO(erikj): Do some more intelligent retries.
- logger.exception("Failed to get PDU")
- else:
- # We need to get the state at this event, since we have reached
- # a backward extremity edge.
- logger.debug(
- "_handle_new_pdu getting state for %s",
- pdu.room_id
- )
- state, auth_chain = yield self.get_state_for_room(
- origin, pdu.room_id, pdu.event_id,
- )
-
- if not backfilled:
- ret = yield self.handler.on_receive_pdu(
- origin,
- pdu,
- backfilled=backfilled,
- state=state,
- auth_chain=auth_chain,
- )
- else:
- ret = None
-
- # yield self.pdu_actions.mark_as_processed(pdu)
+ self.transaction_actions = TransactionActions(self.store)
+ self._transaction_queue = TransactionQueue(hs, transport_layer)
- defer.returnValue(ret)
+ self._order = 0
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
-
- def event_from_pdu_json(self, pdu_json, outlier=False):
- event = FrozenEvent(
- pdu_json
- )
-
- event.internal_metadata.outlier = outlier
-
- return event
-
-
-class _TransactionQueue(object):
- """This class makes sure we only have one transaction in flight at
- a time for a given destination.
-
- It batches pending PDUs into single transactions.
- """
-
- def __init__(self, hs, transaction_actions, transport_layer):
- self.server_name = hs.hostname
- self.transaction_actions = transaction_actions
- self.transport_layer = transport_layer
-
- self._clock = hs.get_clock()
- self.store = hs.get_datastore()
-
- # Is a mapping from destinations -> deferreds. Used to keep track
- # of which destinations have transactions in flight and when they are
- # done
- self.pending_transactions = {}
-
- # Is a mapping from destination -> list of
- # tuple(pending pdus, deferred, order)
- self.pending_pdus_by_dest = {}
- # destination -> list of tuple(edu, deferred)
- self.pending_edus_by_dest = {}
-
- # destination -> list of tuple(failure, deferred)
- self.pending_failures_by_dest = {}
-
- # HACK to get unique tx id
- self._next_txn_id = int(self._clock.time_msec())
-
- @defer.inlineCallbacks
- @log_function
- def enqueue_pdu(self, pdu, destinations, order):
- # We loop through all destinations to see whether we already have
- # a transaction in progress. If we do, stick it in the pending_pdus
- # table and we'll get back to it later.
-
- destinations = set(destinations)
- destinations.discard(self.server_name)
- destinations.discard("localhost")
-
- logger.debug("Sending to: %s", str(destinations))
-
- if not destinations:
- return
-
- deferreds = []
-
- for destination in destinations:
- deferred = defer.Deferred()
- self.pending_pdus_by_dest.setdefault(destination, []).append(
- (pdu, deferred, order)
- )
-
- def eb(failure):
- if not deferred.called:
- deferred.errback(failure)
- else:
- logger.warn("Failed to send pdu", failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(eb)
-
- deferreds.append(deferred)
-
- yield defer.DeferredList(deferreds)
-
- # NO inlineCallbacks
- def enqueue_edu(self, edu):
- destination = edu.destination
-
- if destination == self.server_name:
- return
-
- deferred = defer.Deferred()
- self.pending_edus_by_dest.setdefault(destination, []).append(
- (edu, deferred)
- )
-
- def eb(failure):
- if not deferred.called:
- deferred.errback(failure)
- else:
- logger.warn("Failed to send edu", failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(eb)
-
- return deferred
-
- @defer.inlineCallbacks
- def enqueue_failure(self, failure, destination):
- deferred = defer.Deferred()
-
- self.pending_failures_by_dest.setdefault(
- destination, []
- ).append(
- (failure, deferred)
- )
-
- yield deferred
-
- @defer.inlineCallbacks
- @log_function
- def _attempt_new_transaction(self, destination):
-
- (retry_last_ts, retry_interval) = (0, 0)
- retry_timings = yield self.store.get_destination_retry_timings(
- destination
- )
- if retry_timings:
- (retry_last_ts, retry_interval) = (
- retry_timings.retry_last_ts, retry_timings.retry_interval
- )
- if retry_last_ts + retry_interval > int(self._clock.time_msec()):
- logger.info(
- "TX [%s] not ready for retry yet - "
- "dropping transaction for now",
- destination,
- )
- return
- else:
- logger.info("TX [%s] is ready for retry", destination)
-
- logger.info("TX [%s] _attempt_new_transaction", destination)
-
- if destination in self.pending_transactions:
- # XXX: pending_transactions can get stuck on by a never-ending
- # request at which point pending_pdus_by_dest just keeps growing.
- # we need application-layer timeouts of some flavour of these
- # requests
- return
-
- # list of (pending_pdu, deferred, order)
- pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
- pending_edus = self.pending_edus_by_dest.pop(destination, [])
- pending_failures = self.pending_failures_by_dest.pop(destination, [])
-
- if pending_pdus:
- logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
- destination, len(pending_pdus))
-
- if not pending_pdus and not pending_edus and not pending_failures:
- return
-
- logger.debug(
- "TX [%s] Attempting new transaction"
- " (pdus: %d, edus: %d, failures: %d)",
- destination,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures)
- )
-
- # Sort based on the order field
- pending_pdus.sort(key=lambda t: t[2])
-
- pdus = [x[0] for x in pending_pdus]
- edus = [x[0] for x in pending_edus]
- failures = [x[0].get_dict() for x in pending_failures]
- deferreds = [
- x[1]
- for x in pending_pdus + pending_edus + pending_failures
- ]
-
- try:
- self.pending_transactions[destination] = 1
-
- logger.debug("TX [%s] Persisting transaction...", destination)
-
- transaction = Transaction.create_new(
- origin_server_ts=int(self._clock.time_msec()),
- transaction_id=str(self._next_txn_id),
- origin=self.server_name,
- destination=destination,
- pdus=pdus,
- edus=edus,
- pdu_failures=failures,
- )
-
- self._next_txn_id += 1
-
- yield self.transaction_actions.prepare_to_send(transaction)
-
- logger.debug("TX [%s] Persisted transaction", destination)
- logger.info(
- "TX [%s] Sending transaction [%s]",
- destination,
- transaction.transaction_id,
- )
-
- # Actually send the transaction
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def json_data_cb():
- data = transaction.get_dict()
- now = int(self._clock.time_msec())
- if "pdus" in data:
- for p in data["pdus"]:
- if "age_ts" in p:
- unsigned = p.setdefault("unsigned", {})
- unsigned["age"] = now - int(p["age_ts"])
- del p["age_ts"]
- return data
-
- code, response = yield self.transport_layer.send_transaction(
- transaction, json_data_cb
- )
-
- logger.info("TX [%s] got %d response", destination, code)
-
- logger.debug("TX [%s] Sent transaction", destination)
- logger.debug("TX [%s] Marking as delivered...", destination)
-
- yield self.transaction_actions.delivered(
- transaction, code, response
- )
-
- logger.debug("TX [%s] Marked as delivered", destination)
- logger.debug("TX [%s] Yielding to callbacks...", destination)
-
- for deferred in deferreds:
- if code == 200:
- if retry_last_ts:
- # this host is alive! reset retry schedule
- yield self.store.set_destination_retry_timings(
- destination, 0, 0
- )
- deferred.callback(None)
- else:
- self.set_retrying(destination, retry_interval)
- deferred.errback(RuntimeError("Got status %d" % code))
-
- # Ensures we don't continue until all callbacks on that
- # deferred have fired
- try:
- yield deferred
- except:
- pass
-
- logger.debug("TX [%s] Yielded to callbacks", destination)
-
- except Exception as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
-
- self.set_retrying(destination, retry_interval)
-
- for deferred in deferreds:
- if not deferred.called:
- deferred.errback(e)
-
- finally:
- # We want to be *very* sure we delete this after we stop processing
- self.pending_transactions.pop(destination, None)
-
- # Check to see if there is anything else to send.
- self._attempt_new_transaction(destination)
-
- @defer.inlineCallbacks
- def set_retrying(self, destination, retry_interval):
- # track that this destination is having problems and we should
- # give it a chance to recover before trying it again
-
- if retry_interval:
- retry_interval *= 2
- # plateau at hourly retries for now
- if retry_interval >= 60 * 60 * 1000:
- retry_interval = 60 * 60 * 1000
- else:
- retry_interval = 2000 # try again at first after 2 seconds
-
- yield self.store.set_destination_retry_timings(
- destination,
- int(self._clock.time_msec()),
- retry_interval
- )
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
new file mode 100644
index 0000000000..9d4f2c09a2
--- /dev/null
+++ b/synapse/federation/transaction_queue.py
@@ -0,0 +1,317 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from .persistence import TransactionActions
+from .units import Transaction
+
+from synapse.util.logutils import log_function
+from synapse.util.logcontext import PreserveLoggingContext
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class TransactionQueue(object):
+ """This class makes sure we only have one transaction in flight at
+ a time for a given destination.
+
+ It batches pending PDUs into single transactions.
+ """
+
+ def __init__(self, hs, transport_layer):
+ self.server_name = hs.hostname
+
+ self.store = hs.get_datastore()
+ self.transaction_actions = TransactionActions(self.store)
+
+ self.transport_layer = transport_layer
+
+ self._clock = hs.get_clock()
+
+ # Is a mapping from destinations -> deferreds. Used to keep track
+ # of which destinations have transactions in flight and when they are
+ # done
+ self.pending_transactions = {}
+
+ # Is a mapping from destination -> list of
+ # tuple(pending pdus, deferred, order)
+ self.pending_pdus_by_dest = {}
+ # destination -> list of tuple(edu, deferred)
+ self.pending_edus_by_dest = {}
+
+ # destination -> list of tuple(failure, deferred)
+ self.pending_failures_by_dest = {}
+
+ # HACK to get unique tx id
+ self._next_txn_id = int(self._clock.time_msec())
+
+ @defer.inlineCallbacks
+ @log_function
+ def enqueue_pdu(self, pdu, destinations, order):
+ # We loop through all destinations to see whether we already have
+ # a transaction in progress. If we do, stick it in the pending_pdus
+ # table and we'll get back to it later.
+
+ destinations = set(destinations)
+ destinations.discard(self.server_name)
+ destinations.discard("localhost")
+
+ logger.debug("Sending to: %s", str(destinations))
+
+ if not destinations:
+ return
+
+ deferreds = []
+
+ for destination in destinations:
+ deferred = defer.Deferred()
+ self.pending_pdus_by_dest.setdefault(destination, []).append(
+ (pdu, deferred, order)
+ )
+
+ def eb(failure):
+ if not deferred.called:
+ deferred.errback(failure)
+ else:
+ logger.warn("Failed to send pdu", failure)
+
+ with PreserveLoggingContext():
+ self._attempt_new_transaction(destination).addErrback(eb)
+
+ deferreds.append(deferred)
+
+ yield defer.DeferredList(deferreds)
+
+ # NO inlineCallbacks
+ def enqueue_edu(self, edu):
+ destination = edu.destination
+
+ if destination == self.server_name:
+ return
+
+ deferred = defer.Deferred()
+ self.pending_edus_by_dest.setdefault(destination, []).append(
+ (edu, deferred)
+ )
+
+ def eb(failure):
+ if not deferred.called:
+ deferred.errback(failure)
+ else:
+ logger.warn("Failed to send edu", failure)
+
+ with PreserveLoggingContext():
+ self._attempt_new_transaction(destination).addErrback(eb)
+
+ return deferred
+
+ @defer.inlineCallbacks
+ def enqueue_failure(self, failure, destination):
+ deferred = defer.Deferred()
+
+ self.pending_failures_by_dest.setdefault(
+ destination, []
+ ).append(
+ (failure, deferred)
+ )
+
+ yield deferred
+
+ @defer.inlineCallbacks
+ @log_function
+ def _attempt_new_transaction(self, destination):
+
+ (retry_last_ts, retry_interval) = (0, 0)
+ retry_timings = yield self.store.get_destination_retry_timings(
+ destination
+ )
+ if retry_timings:
+ (retry_last_ts, retry_interval) = (
+ retry_timings.retry_last_ts, retry_timings.retry_interval
+ )
+ if retry_last_ts + retry_interval > int(self._clock.time_msec()):
+ logger.info(
+ "TX [%s] not ready for retry yet - "
+ "dropping transaction for now",
+ destination,
+ )
+ return
+ else:
+ logger.info("TX [%s] is ready for retry", destination)
+
+ logger.info("TX [%s] _attempt_new_transaction", destination)
+
+ if destination in self.pending_transactions:
+ # XXX: pending_transactions can get stuck on by a never-ending
+ # request at which point pending_pdus_by_dest just keeps growing.
+ # we need application-layer timeouts of some flavour of these
+ # requests
+ return
+
+ # list of (pending_pdu, deferred, order)
+ pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+ pending_edus = self.pending_edus_by_dest.pop(destination, [])
+ pending_failures = self.pending_failures_by_dest.pop(destination, [])
+
+ if pending_pdus:
+ logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+ destination, len(pending_pdus))
+
+ if not pending_pdus and not pending_edus and not pending_failures:
+ return
+
+ logger.debug(
+ "TX [%s] Attempting new transaction"
+ " (pdus: %d, edus: %d, failures: %d)",
+ destination,
+ len(pending_pdus),
+ len(pending_edus),
+ len(pending_failures)
+ )
+
+ # Sort based on the order field
+ pending_pdus.sort(key=lambda t: t[2])
+
+ pdus = [x[0] for x in pending_pdus]
+ edus = [x[0] for x in pending_edus]
+ failures = [x[0].get_dict() for x in pending_failures]
+ deferreds = [
+ x[1]
+ for x in pending_pdus + pending_edus + pending_failures
+ ]
+
+ try:
+ self.pending_transactions[destination] = 1
+
+ logger.debug("TX [%s] Persisting transaction...", destination)
+
+ transaction = Transaction.create_new(
+ origin_server_ts=int(self._clock.time_msec()),
+ transaction_id=str(self._next_txn_id),
+ origin=self.server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ pdu_failures=failures,
+ )
+
+ self._next_txn_id += 1
+
+ yield self.transaction_actions.prepare_to_send(transaction)
+
+ logger.debug("TX [%s] Persisted transaction", destination)
+ logger.info(
+ "TX [%s] Sending transaction [%s]",
+ destination,
+ transaction.transaction_id,
+ )
+
+ # Actually send the transaction
+
+ # FIXME (erikj): This is a bit of a hack to make the Pdu age
+ # keys work
+ def json_data_cb():
+ data = transaction.get_dict()
+ now = int(self._clock.time_msec())
+ if "pdus" in data:
+ for p in data["pdus"]:
+ if "age_ts" in p:
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
+ return data
+
+ code, response = yield self.transport_layer.send_transaction(
+ transaction, json_data_cb
+ )
+
+ logger.info("TX [%s] got %d response", destination, code)
+
+ logger.debug("TX [%s] Sent transaction", destination)
+ logger.debug("TX [%s] Marking as delivered...", destination)
+
+ yield self.transaction_actions.delivered(
+ transaction, code, response
+ )
+
+ logger.debug("TX [%s] Marked as delivered", destination)
+ logger.debug("TX [%s] Yielding to callbacks...", destination)
+
+ for deferred in deferreds:
+ if code == 200:
+ if retry_last_ts:
+ # this host is alive! reset retry schedule
+ yield self.store.set_destination_retry_timings(
+ destination, 0, 0
+ )
+ deferred.callback(None)
+ else:
+ self.set_retrying(destination, retry_interval)
+ deferred.errback(RuntimeError("Got status %d" % code))
+
+ # Ensures we don't continue until all callbacks on that
+ # deferred have fired
+ try:
+ yield deferred
+ except:
+ pass
+
+ logger.debug("TX [%s] Yielded to callbacks", destination)
+
+ except Exception as e:
+ # We capture this here as there as nothing actually listens
+ # for this finishing functions deferred.
+ logger.warn(
+ "TX [%s] Problem in _attempt_transaction: %s",
+ destination,
+ e,
+ )
+
+ self.set_retrying(destination, retry_interval)
+
+ for deferred in deferreds:
+ if not deferred.called:
+ deferred.errback(e)
+
+ finally:
+ # We want to be *very* sure we delete this after we stop processing
+ self.pending_transactions.pop(destination, None)
+
+ # Check to see if there is anything else to send.
+ self._attempt_new_transaction(destination)
+
+ @defer.inlineCallbacks
+ def set_retrying(self, destination, retry_interval):
+ # track that this destination is having problems and we should
+ # give it a chance to recover before trying it again
+
+ if retry_interval:
+ retry_interval *= 2
+ # plateau at hourly retries for now
+ if retry_interval >= 60 * 60 * 1000:
+ retry_interval = 60 * 60 * 1000
+ else:
+ retry_interval = 2000 # try again at first after 2 seconds
+
+ yield self.store.set_destination_retry_timings(
+ destination,
+ int(self._clock.time_msec()),
+ retry_interval
+ )
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index e634a3a213..4cb1dea2de 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -213,3 +213,19 @@ class TransportLayerClient(object):
)
defer.returnValue(response)
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_query_auth(self, destination, room_id, event_id, content):
+ path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
+
+ code, content = yield self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ )
+
+ if not 200 <= code < 300:
+ raise RuntimeError("Got %d from send_invite", code)
+
+ defer.returnValue(json.loads(content))
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a380a6910b..9c9f8d525b 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -42,7 +42,7 @@ class TransportLayerServer(object):
content = None
origin = None
- if request.method == "PUT":
+ if request.method in ["PUT", "POST"]:
# TODO: Handle other method types? other content types?
try:
content_bytes = request.content.read()
@@ -234,6 +234,16 @@ class TransportLayerServer(object):
)
)
)
+ self.server.register_path(
+ "POST",
+ re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
+ self._with_authentication(
+ lambda origin, content, query, context, event_id:
+ self._on_query_auth_request(
+ origin, content, event_id,
+ )
+ )
+ )
@defer.inlineCallbacks
@log_function
@@ -325,3 +335,12 @@ class TransportLayerServer(object):
)
defer.returnValue((200, content))
+
+ @defer.inlineCallbacks
+ @log_function
+ def _on_query_auth_request(self, origin, content, event_id):
+ new_content = yield self.request_handler.on_query_auth_request(
+ origin, content, event_id
+ )
+
+ defer.returnValue((200, new_content))
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 96a9b143ca..b31518bf62 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -27,6 +27,7 @@ from .directory import DirectoryHandler
from .typing import TypingNotificationHandler
from .admin import AdminHandler
from .appservice import ApplicationServicesHandler
+from .sync import SyncHandler
class Handlers(object):
@@ -53,3 +54,4 @@ class Handlers(object):
self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs)
self.appservice_handler = ApplicationServicesHandler(hs)
+ self.sync_handler = SyncHandler(hs)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d997917cd6..025e7e7e62 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -49,24 +49,25 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0,
- as_client_event=True):
+ as_client_event=True, affect_presence=True):
auth_user = UserID.from_string(auth_user_id)
try:
- if auth_user not in self._streams_per_user:
- self._streams_per_user[auth_user] = 0
- if auth_user in self._stop_timer_per_user:
- try:
- self.clock.cancel_call_later(
- self._stop_timer_per_user.pop(auth_user)
+ if affect_presence:
+ if auth_user not in self._streams_per_user:
+ self._streams_per_user[auth_user] = 0
+ if auth_user in self._stop_timer_per_user:
+ try:
+ self.clock.cancel_call_later(
+ self._stop_timer_per_user.pop(auth_user)
+ )
+ except:
+ logger.exception("Failed to cancel event timer")
+ else:
+ yield self.distributor.fire(
+ "started_user_eventstream", auth_user
)
- except:
- logger.exception("Failed to cancel event timer")
- else:
- yield self.distributor.fire(
- "started_user_eventstream", auth_user
- )
- self._streams_per_user[auth_user] += 1
+ self._streams_per_user[auth_user] += 1
if pagin_config.from_token is None:
pagin_config.from_token = None
@@ -94,27 +95,28 @@ class EventStreamHandler(BaseHandler):
defer.returnValue(chunk)
finally:
- self._streams_per_user[auth_user] -= 1
- if not self._streams_per_user[auth_user]:
- del self._streams_per_user[auth_user]
-
- # 10 seconds of grace to allow the client to reconnect again
- # before we think they're gone
- def _later():
- logger.debug(
- "_later stopped_user_eventstream %s", auth_user
- )
+ if affect_presence:
+ self._streams_per_user[auth_user] -= 1
+ if not self._streams_per_user[auth_user]:
+ del self._streams_per_user[auth_user]
+
+ # 10 seconds of grace to allow the client to reconnect again
+ # before we think they're gone
+ def _later():
+ logger.debug(
+ "_later stopped_user_eventstream %s", auth_user
+ )
- self._stop_timer_per_user.pop(auth_user, None)
+ self._stop_timer_per_user.pop(auth_user, None)
- yield self.distributor.fire(
- "stopped_user_eventstream", auth_user
- )
+ return self.distributor.fire(
+ "stopped_user_eventstream", auth_user
+ )
- logger.debug("Scheduling _later: for %s", auth_user)
- self._stop_timer_per_user[auth_user] = (
- self.clock.call_later(30, _later)
- )
+ logger.debug("Scheduling _later: for %s", auth_user)
+ self._stop_timer_per_user[auth_user] = (
+ self.clock.call_later(30, _later)
+ )
class EventHandler(BaseHandler):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bcdcc90a18..8bf5a4cc11 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,19 +17,16 @@
from ._base import BaseHandler
-from synapse.events.utils import prune_event
from synapse.api.errors import (
- AuthError, FederationError, SynapseError, StoreError,
+ AuthError, FederationError, StoreError,
)
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import (
- compute_event_signature, check_event_content_hash,
- add_hashes_and_signatures,
+ compute_event_signature, add_hashes_and_signatures,
)
from synapse.types import UserID
-from syutil.jsonutil import encode_canonical_json
from twisted.internet import defer
@@ -113,33 +110,6 @@ class FederationHandler(BaseHandler):
logger.debug("Processing event: %s", event.event_id)
- redacted_event = prune_event(event)
-
- redacted_pdu_json = redacted_event.get_pdu_json()
- try:
- yield self.keyring.verify_json_for_server(
- event.origin, redacted_pdu_json
- )
- except SynapseError as e:
- logger.warn(
- "Signature check failed for %s redacted to %s",
- encode_canonical_json(pdu.get_pdu_json()),
- encode_canonical_json(redacted_pdu_json),
- )
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=event.event_id,
- )
-
- if not check_event_content_hash(event):
- logger.warn(
- "Event content has been tampered, redacting %s, %s",
- event.event_id, encode_canonical_json(event.get_dict())
- )
- event = redacted_event
-
logger.debug("Event: %s", event)
# FIXME (erikj): Awful hack to make the case where we are not currently
@@ -149,41 +119,20 @@ class FederationHandler(BaseHandler):
event.room_id,
self.server_name
)
- if not is_in_room and not event.internal_metadata.outlier:
+ if not is_in_room and not event.internal_metadata.is_outlier():
logger.debug("Got event for room we're not in.")
-
- replication = self.replication_layer
-
- if not state:
- state, auth_chain = yield replication.get_state_for_room(
- origin, context=event.room_id, event_id=event.event_id,
- )
-
- if not auth_chain:
- auth_chain = yield replication.get_event_auth(
- origin,
- context=event.room_id,
- event_id=event.event_id,
- )
-
- for e in auth_chain:
- e.internal_metadata.outlier = True
- try:
- yield self._handle_new_event(e, fetch_auth_from=origin)
- except:
- logger.exception(
- "Failed to handle auth event %s",
- e.event_id,
- )
-
current_state = state
- if state:
+ if state and auth_chain is not None:
for e in state:
- logging.info("A :) %r", e)
e.internal_metadata.outlier = True
try:
- yield self._handle_new_event(e)
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+ yield self._handle_new_event(origin, e, auth_events=auth)
except:
logger.exception(
"Failed to handle state event %s",
@@ -192,6 +141,7 @@ class FederationHandler(BaseHandler):
try:
yield self._handle_new_event(
+ origin,
event,
state=state,
backfilled=backfilled,
@@ -393,8 +343,19 @@ class FederationHandler(BaseHandler):
for e in auth_chain:
e.internal_metadata.outlier = True
+
+ if e.event_id == event.event_id:
+ continue
+
try:
- yield self._handle_new_event(e)
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+ yield self._handle_new_event(
+ target_host, e, auth_events=auth
+ )
except:
logger.exception(
"Failed to handle auth event %s",
@@ -402,11 +363,18 @@ class FederationHandler(BaseHandler):
)
for e in state:
- # FIXME: Auth these.
+ if e.event_id == event.event_id:
+ continue
+
e.internal_metadata.outlier = True
try:
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
yield self._handle_new_event(
- e, fetch_auth_from=target_host
+ target_host, e, auth_events=auth
)
except:
logger.exception(
@@ -414,10 +382,18 @@ class FederationHandler(BaseHandler):
e.event_id,
)
+ auth_ids = [e_id for e_id, _ in event.auth_events]
+ auth_events = {
+ (e.type, e.state_key): e for e in auth_chain
+ if e.event_id in auth_ids
+ }
+
yield self._handle_new_event(
+ target_host,
new_event,
state=state,
current_state=state,
+ auth_events=auth_events,
)
yield self.notifier.on_new_room_event(
@@ -481,7 +457,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
- context = yield self._handle_new_event(event)
+ context = yield self._handle_new_event(origin, event)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -682,11 +658,12 @@ class FederationHandler(BaseHandler):
waiters.pop().callback(None)
@defer.inlineCallbacks
- def _handle_new_event(self, event, state=None, backfilled=False,
- current_state=None, fetch_auth_from=None):
+ @log_function
+ def _handle_new_event(self, origin, event, state=None, backfilled=False,
+ current_state=None, auth_events=None):
logger.debug(
- "_handle_new_event: Before annotate: %s, sigs: %s",
+ "_handle_new_event: %s, sigs: %s",
event.event_id, event.signatures,
)
@@ -694,65 +671,44 @@ class FederationHandler(BaseHandler):
event, old_state=state
)
+ if not auth_events:
+ auth_events = context.auth_events
+
logger.debug(
- "_handle_new_event: Before auth fetch: %s, sigs: %s",
- event.event_id, event.signatures,
+ "_handle_new_event: %s, auth_events: %s",
+ event.event_id, auth_events,
)
is_new_state = not event.internal_metadata.is_outlier()
- known_ids = set(
- [s.event_id for s in context.auth_events.values()]
- )
-
- for e_id, _ in event.auth_events:
- if e_id not in known_ids:
- e = yield self.store.get_event(e_id, allow_none=True)
-
- if not e and fetch_auth_from is not None:
- # Grab the auth_chain over federation if we are missing
- # auth events.
- auth_chain = yield self.replication_layer.get_event_auth(
- fetch_auth_from, event.event_id, event.room_id
- )
- for auth_event in auth_chain:
- yield self._handle_new_event(auth_event)
- e = yield self.store.get_event(e_id, allow_none=True)
-
- if not e:
- # TODO: Do some conflict res to make sure that we're
- # not the ones who are wrong.
- logger.info(
- "Rejecting %s as %s not in db or %s",
- event.event_id, e_id, known_ids,
- )
- # FIXME: How does raising AuthError work with federation?
- raise AuthError(403, "Cannot find auth event")
-
- context.auth_events[(e.type, e.state_key)] = e
-
- logger.debug(
- "_handle_new_event: Before hack: %s, sigs: %s",
- event.event_id, event.signatures,
- )
-
+ # This is a hack to fix some old rooms where the initial join event
+ # didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create:
- context.auth_events[(c.type, c.state_key)] = c
+ auth_events[(c.type, c.state_key)] = c
- logger.debug(
- "_handle_new_event: Before auth check: %s, sigs: %s",
- event.event_id, event.signatures,
- )
+ try:
+ yield self.do_auth(
+ origin, event, context, auth_events=auth_events
+ )
+ except AuthError as e:
+ logger.warn(
+ "Rejecting %s because %s",
+ event.event_id, e.msg
+ )
- self.auth.check(event, auth_events=context.auth_events)
+ context.rejected = RejectedReason.AUTH_ERROR
- logger.debug(
- "_handle_new_event: Before persist_event: %s, sigs: %s",
- event.event_id, event.signatures,
- )
+ yield self.store.persist_event(
+ event,
+ context=context,
+ backfilled=backfilled,
+ is_new_state=False,
+ current_state=current_state,
+ )
+ raise
yield self.store.persist_event(
event,
@@ -762,9 +718,294 @@ class FederationHandler(BaseHandler):
current_state=current_state,
)
- logger.debug(
- "_handle_new_event: After persist_event: %s, sigs: %s",
- event.event_id, event.signatures,
+ defer.returnValue(context)
+
+ @defer.inlineCallbacks
+ def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
+ missing):
+ # Just go through and process each event in `remote_auth_chain`. We
+ # don't want to fall into the trap of `missing` being wrong.
+ for e in remote_auth_chain:
+ try:
+ yield self._handle_new_event(origin, e)
+ except AuthError:
+ pass
+
+ # Now get the current auth_chain for the event.
+ local_auth_chain = yield self.store.get_auth_chain([event_id])
+
+ # TODO: Check if we would now reject event_id. If so we need to tell
+ # everyone.
+
+ ret = yield self.construct_auth_difference(
+ local_auth_chain, remote_auth_chain
)
- defer.returnValue(context)
+ for event in ret["auth_chain"]:
+ event.signatures.update(
+ compute_event_signature(
+ event,
+ self.hs.hostname,
+ self.hs.config.signing_key[0]
+ )
+ )
+
+ logger.debug("on_query_auth reutrning: %s", ret)
+
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ @log_function
+ def do_auth(self, origin, event, context, auth_events):
+ # Check if we have all the auth events.
+ res = yield self.store.have_events(
+ [e_id for e_id, _ in event.auth_events]
+ )
+
+ event_auth_events = set(e_id for e_id, _ in event.auth_events)
+ seen_events = set(res.keys())
+
+ missing_auth = event_auth_events - seen_events
+
+ if missing_auth:
+ logger.debug("Missing auth: %s", missing_auth)
+ # If we don't have all the auth events, we need to get them.
+ remote_auth_chain = yield self.replication_layer.get_event_auth(
+ origin, event.room_id, event.event_id
+ )
+
+ seen_remotes = yield self.store.have_events(
+ [e.event_id for e in remote_auth_chain]
+ )
+
+ for e in remote_auth_chain:
+ if e.event_id in seen_remotes.keys():
+ continue
+
+ if e.event_id == event.event_id:
+ continue
+
+ try:
+ auth_ids = [e_id for e_id, _ in e.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in remote_auth_chain
+ if e.event_id in auth_ids
+ }
+ e.internal_metadata.outlier = True
+
+ logger.debug(
+ "do_auth %s missing_auth: %s",
+ event.event_id, e.event_id
+ )
+ yield self._handle_new_event(
+ origin, e, auth_events=auth
+ )
+
+ if e.event_id in event_auth_events:
+ auth_events[(e.type, e.state_key)] = e
+ except AuthError:
+ pass
+
+ # FIXME: Assumes we have and stored all the state for all the
+ # prev_events
+ current_state = set(e.event_id for e in auth_events.values())
+ different_auth = event_auth_events - current_state
+
+ if different_auth and not event.internal_metadata.is_outlier():
+ # Do auth conflict res.
+ logger.debug("Different auth: %s", different_auth)
+
+ # 1. Get what we think is the auth chain.
+ auth_ids = self.auth.compute_auth_events(event, context)
+ local_auth_chain = yield self.store.get_auth_chain(auth_ids)
+
+ # 2. Get remote difference.
+ result = yield self.replication_layer.query_auth(
+ origin,
+ event.room_id,
+ event.event_id,
+ local_auth_chain,
+ )
+
+ seen_remotes = yield self.store.have_events(
+ [e.event_id for e in result["auth_chain"]]
+ )
+
+ # 3. Process any remote auth chain events we haven't seen.
+ for ev in result["auth_chain"]:
+ if ev.event_id in seen_remotes.keys():
+ continue
+
+ if ev.event_id == event.event_id:
+ continue
+
+ try:
+ auth_ids = [e_id for e_id, _ in ev.auth_events]
+ auth = {
+ (e.type, e.state_key): e for e in result["auth_chain"]
+ if e.event_id in auth_ids
+ }
+ ev.internal_metadata.outlier = True
+
+ logger.debug(
+ "do_auth %s different_auth: %s",
+ event.event_id, e.event_id
+ )
+
+ yield self._handle_new_event(
+ origin, ev, auth_events=auth
+ )
+
+ if ev.event_id in event_auth_events:
+ auth_events[(ev.type, ev.state_key)] = ev
+ except AuthError:
+ pass
+
+ # 4. Look at rejects and their proofs.
+ # TODO.
+
+ context.current_state.update(auth_events)
+ context.state_group = None
+
+ try:
+ self.auth.check(event, auth_events=auth_events)
+ except AuthError:
+ raise
+
+ @defer.inlineCallbacks
+ def construct_auth_difference(self, local_auth, remote_auth):
+ """ Given a local and remote auth chain, find the differences. This
+ assumes that we have already processed all events in remote_auth
+
+ Params:
+ local_auth (list)
+ remote_auth (list)
+
+ Returns:
+ dict
+ """
+
+ logger.debug("construct_auth_difference Start!")
+
+ # TODO: Make sure we are OK with local_auth or remote_auth having more
+ # auth events in them than strictly necessary.
+
+ def sort_fun(ev):
+ return ev.depth, ev.event_id
+
+ logger.debug("construct_auth_difference after sort_fun!")
+
+ # We find the differences by starting at the "bottom" of each list
+ # and iterating up on both lists. The lists are ordered by depth and
+ # then event_id, we iterate up both lists until we find the event ids
+ # don't match. Then we look at depth/event_id to see which side is
+ # missing that event, and iterate only up that list. Repeat.
+
+ remote_list = list(remote_auth)
+ remote_list.sort(key=sort_fun)
+
+ local_list = list(local_auth)
+ local_list.sort(key=sort_fun)
+
+ local_iter = iter(local_list)
+ remote_iter = iter(remote_list)
+
+ logger.debug("construct_auth_difference before get_next!")
+
+ def get_next(it, opt=None):
+ try:
+ return it.next()
+ except:
+ return opt
+
+ current_local = get_next(local_iter)
+ current_remote = get_next(remote_iter)
+
+ logger.debug("construct_auth_difference before while")
+
+ missing_remotes = []
+ missing_locals = []
+ while current_local or current_remote:
+ if current_remote is None:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+ continue
+
+ if current_local is None:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ if current_local.event_id == current_remote.event_id:
+ current_local = get_next(local_iter)
+ current_remote = get_next(remote_iter)
+ continue
+
+ if current_local.depth < current_remote.depth:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+ continue
+
+ if current_local.depth > current_remote.depth:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ # They have the same depth, so we fall back to the event_id order
+ if current_local.event_id < current_remote.event_id:
+ missing_locals.append(current_local)
+ current_local = get_next(local_iter)
+
+ if current_local.event_id > current_remote.event_id:
+ missing_remotes.append(current_remote)
+ current_remote = get_next(remote_iter)
+ continue
+
+ logger.debug("construct_auth_difference after while")
+
+ # missing locals should be sent to the server
+ # We should find why we are missing remotes, as they will have been
+ # rejected.
+
+ # Remove events from missing_remotes if they are referencing a missing
+ # remote. We only care about the "root" rejected ones.
+ missing_remote_ids = [e.event_id for e in missing_remotes]
+ base_remote_rejected = list(missing_remotes)
+ for e in missing_remotes:
+ for e_id, _ in e.auth_events:
+ if e_id in missing_remote_ids:
+ base_remote_rejected.remove(e)
+
+ reason_map = {}
+
+ for e in base_remote_rejected:
+ reason = yield self.store.get_rejection_reason(e.event_id)
+ if reason is None:
+ # FIXME: ERRR?!
+ logger.warn("Could not find reason for %s", e.event_id)
+ raise RuntimeError("")
+
+ reason_map[e.event_id] = reason
+
+ if reason == RejectedReason.AUTH_ERROR:
+ pass
+ elif reason == RejectedReason.REPLACED:
+ # TODO: Get proof
+ pass
+ elif reason == RejectedReason.NOT_ANCESTOR:
+ # TODO: Get proof.
+ pass
+
+ logger.debug("construct_auth_difference returning")
+
+ defer.returnValue({
+ "auth_chain": local_auth,
+ "rejects": {
+ e.event_id: {
+ "reason": reason_map[e.event_id],
+ "proof": None,
+ }
+ for e in base_remote_rejected
+ },
+ "missing": [e.event_id for e in missing_locals],
+ })
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9c3271fe88..6fbd2af4ab 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -114,7 +114,8 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
- def create_and_send_event(self, event_dict, ratelimit=True):
+ def create_and_send_event(self, event_dict, ratelimit=True,
+ client=None, txn_id=None):
""" Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
@@ -148,6 +149,15 @@ class MessageHandler(BaseHandler):
builder.content
)
+ if client is not None:
+ if client.token_id is not None:
+ builder.internal_metadata.token_id = client.token_id
+ if client.device_id is not None:
+ builder.internal_metadata.device_id = client.device_id
+
+ if txn_id is not None:
+ builder.internal_metadata.txn_id = txn_id
+
event, context = yield self._create_new_client_event(
builder=builder,
)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index d66bfea7b1..cd0798c2b0 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -87,6 +87,10 @@ class PresenceHandler(BaseHandler):
"changed_presencelike_data", self.changed_presencelike_data
)
+ # outbound signal from the presence module to advertise when a user's
+ # presence has changed
+ distributor.declare("user_presence_changed")
+
self.distributor = distributor
self.federation = hs.get_replication_layer()
@@ -604,6 +608,7 @@ class PresenceHandler(BaseHandler):
room_ids=room_ids,
statuscache=statuscache,
)
+ yield self.distributor.fire("user_presence_changed", user, statuscache)
@defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 732652c228..66a89c10b2 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -163,7 +163,7 @@ class RegistrationHandler(BaseHandler):
# each request
httpCli = SimpleHttpClient(self.hs)
# XXX: make this configurable!
- trustedIdServers = ['matrix.org:8090']
+ trustedIdServers = ['matrix.org:8090', 'matrix.org']
if not creds['idServer'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['idServer'])
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
new file mode 100644
index 0000000000..962686f4bb
--- /dev/null
+++ b/synapse/handlers/sync.py
@@ -0,0 +1,434 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import BaseHandler
+
+from synapse.streams.config import PaginationConfig
+from synapse.api.constants import Membership, EventTypes
+
+from twisted.internet import defer
+
+import collections
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+SyncConfig = collections.namedtuple("SyncConfig", [
+ "user",
+ "client_info",
+ "limit",
+ "gap",
+ "sort",
+ "backfill",
+ "filter",
+])
+
+
+class RoomSyncResult(collections.namedtuple("RoomSyncResult", [
+ "room_id",
+ "limited",
+ "published",
+ "events",
+ "state",
+ "prev_batch",
+ "ephemeral",
+])):
+ __slots__ = []
+
+ def __nonzero__(self):
+ """Make the result appear empty if there are no updates. This is used
+ to tell if room needs to be part of the sync result.
+ """
+ return bool(self.events or self.state or self.ephemeral)
+
+
+class SyncResult(collections.namedtuple("SyncResult", [
+ "next_batch", # Token for the next sync
+ "private_user_data", # List of private events for the user.
+ "public_user_data", # List of public events for all users.
+ "rooms", # RoomSyncResult for each room.
+])):
+ __slots__ = []
+
+ def __nonzero__(self):
+ """Make the result appear empty if there are no updates. This is used
+ to tell if the notifier needs to wait for more events when polling for
+ events.
+ """
+ return bool(
+ self.private_user_data or self.public_user_data or self.rooms
+ )
+
+
+class SyncHandler(BaseHandler):
+
+ def __init__(self, hs):
+ super(SyncHandler, self).__init__(hs)
+ self.event_sources = hs.get_event_sources()
+ self.clock = hs.get_clock()
+
+ @defer.inlineCallbacks
+ def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0):
+ """Get the sync for a client if we have new data for it now. Otherwise
+ wait for new data to arrive on the server. If the timeout expires, then
+ return an empty sync result.
+ Returns:
+ A Deferred SyncResult.
+ """
+ if timeout == 0 or since_token is None:
+ result = yield self.current_sync_for_user(sync_config, since_token)
+ defer.returnValue(result)
+ else:
+ def current_sync_callback():
+ return self.current_sync_for_user(sync_config, since_token)
+
+ rm_handler = self.hs.get_handlers().room_member_handler
+ room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
+ result = yield self.notifier.wait_for_events(
+ sync_config.user, room_ids,
+ sync_config.filter, timeout, current_sync_callback
+ )
+ defer.returnValue(result)
+
+ def current_sync_for_user(self, sync_config, since_token=None):
+ """Get the sync for client needed to match what the server has now.
+ Returns:
+ A Deferred SyncResult.
+ """
+ if since_token is None:
+ return self.initial_sync(sync_config)
+ else:
+ if sync_config.gap:
+ return self.incremental_sync_with_gap(sync_config, since_token)
+ else:
+ #TODO(mjark): Handle gapless sync
+ raise NotImplementedError()
+
+ @defer.inlineCallbacks
+ def initial_sync(self, sync_config):
+ """Get a sync for a client which is starting without any state
+ Returns:
+ A Deferred SyncResult.
+ """
+ if sync_config.sort == "timeline,desc":
+ # TODO(mjark): Handle going through events in reverse order?.
+ # What does "most recent events" mean when applying the limits mean
+ # in this case?
+ raise NotImplementedError()
+
+ now_token = yield self.event_sources.get_current_token()
+
+ presence_stream = self.event_sources.sources["presence"]
+ # TODO (mjark): This looks wrong, shouldn't we be getting the presence
+ # UP to the present rather than after the present?
+ pagination_config = PaginationConfig(from_token=now_token)
+ presence, _ = yield presence_stream.get_pagination_rows(
+ user=sync_config.user,
+ pagination_config=pagination_config.get_source_config("presence"),
+ key=None
+ )
+ room_list = yield self.store.get_rooms_for_user_where_membership_is(
+ user_id=sync_config.user.to_string(),
+ membership_list=[Membership.INVITE, Membership.JOIN]
+ )
+
+ # TODO (mjark): Does public mean "published"?
+ published_rooms = yield self.store.get_rooms(is_public=True)
+ published_room_ids = set(r["room_id"] for r in published_rooms)
+
+ rooms = []
+ for event in room_list:
+ room_sync = yield self.initial_sync_for_room(
+ event.room_id, sync_config, now_token, published_room_ids
+ )
+ rooms.append(room_sync)
+
+ defer.returnValue(SyncResult(
+ public_user_data=presence,
+ private_user_data=[],
+ rooms=rooms,
+ next_batch=now_token,
+ ))
+
+ @defer.inlineCallbacks
+ def initial_sync_for_room(self, room_id, sync_config, now_token,
+ published_room_ids):
+ """Sync a room for a client which is starting without any state
+ Returns:
+ A Deferred RoomSyncResult.
+ """
+
+ recents, prev_batch_token, limited = yield self.load_filtered_recents(
+ room_id, sync_config, now_token,
+ )
+
+ current_state_events = yield self.state_handler.get_current_state(
+ room_id
+ )
+
+ defer.returnValue(RoomSyncResult(
+ room_id=room_id,
+ published=room_id in published_room_ids,
+ events=recents,
+ prev_batch=prev_batch_token,
+ state=current_state_events,
+ limited=limited,
+ ephemeral=[],
+ ))
+
+ @defer.inlineCallbacks
+ def incremental_sync_with_gap(self, sync_config, since_token):
+ """ Get the incremental delta needed to bring the client up to
+ date with the server.
+ Returns:
+ A Deferred SyncResult.
+ """
+ if sync_config.sort == "timeline,desc":
+ # TODO(mjark): Handle going through events in reverse order?.
+ # What does "most recent events" mean when applying the limits mean
+ # in this case?
+ raise NotImplementedError()
+
+ now_token = yield self.event_sources.get_current_token()
+
+ presence_source = self.event_sources.sources["presence"]
+ presence, presence_key = yield presence_source.get_new_events_for_user(
+ user=sync_config.user,
+ from_key=since_token.presence_key,
+ limit=sync_config.limit,
+ )
+ now_token = now_token.copy_and_replace("presence_key", presence_key)
+
+ typing_source = self.event_sources.sources["typing"]
+ typing, typing_key = yield typing_source.get_new_events_for_user(
+ user=sync_config.user,
+ from_key=since_token.typing_key,
+ limit=sync_config.limit,
+ )
+ now_token = now_token.copy_and_replace("typing_key", typing_key)
+
+ typing_by_room = {event["room_id"]: [event] for event in typing}
+ for event in typing:
+ event.pop("room_id")
+ logger.debug("Typing %r", typing_by_room)
+
+ rm_handler = self.hs.get_handlers().room_member_handler
+ room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
+
+ # TODO (mjark): Does public mean "published"?
+ published_rooms = yield self.store.get_rooms(is_public=True)
+ published_room_ids = set(r["room_id"] for r in published_rooms)
+
+ room_events, _ = yield self.store.get_room_events_stream(
+ sync_config.user.to_string(),
+ from_key=since_token.room_key,
+ to_key=now_token.room_key,
+ room_id=None,
+ limit=sync_config.limit + 1,
+ )
+
+ rooms = []
+ if len(room_events) <= sync_config.limit:
+ # There is no gap in any of the rooms. Therefore we can just
+ # partition the new events by room and return them.
+ events_by_room_id = {}
+ for event in room_events:
+ events_by_room_id.setdefault(event.room_id, []).append(event)
+
+ for room_id in room_ids:
+ recents = events_by_room_id.get(room_id, [])
+ state = [event for event in recents if event.is_state()]
+ if recents:
+ prev_batch = now_token.copy_and_replace(
+ "room_key", recents[0].internal_metadata.before
+ )
+ else:
+ prev_batch = now_token
+
+ state = yield self.check_joined_room(
+ sync_config, room_id, state
+ )
+
+ room_sync = RoomSyncResult(
+ room_id=room_id,
+ published=room_id in published_room_ids,
+ events=recents,
+ prev_batch=prev_batch,
+ state=state,
+ limited=False,
+ ephemeral=typing_by_room.get(room_id, [])
+ )
+ if room_sync:
+ rooms.append(room_sync)
+ else:
+ for room_id in room_ids:
+ room_sync = yield self.incremental_sync_with_gap_for_room(
+ room_id, sync_config, since_token, now_token,
+ published_room_ids, typing_by_room
+ )
+ if room_sync:
+ rooms.append(room_sync)
+
+ defer.returnValue(SyncResult(
+ public_user_data=presence,
+ private_user_data=[],
+ rooms=rooms,
+ next_batch=now_token,
+ ))
+
+ @defer.inlineCallbacks
+ def load_filtered_recents(self, room_id, sync_config, now_token,
+ since_token=None):
+ limited = True
+ recents = []
+ filtering_factor = 2
+ load_limit = max(sync_config.limit * filtering_factor, 100)
+ max_repeat = 3 # Only try a few times per room, otherwise
+ room_key = now_token.room_key
+
+ while limited and len(recents) < sync_config.limit and max_repeat:
+ events, keys = yield self.store.get_recent_events_for_room(
+ room_id,
+ limit=load_limit + 1,
+ from_token=since_token.room_key if since_token else None,
+ end_token=room_key,
+ )
+ (room_key, _) = keys
+ loaded_recents = sync_config.filter.filter_room_events(events)
+ loaded_recents.extend(recents)
+ recents = loaded_recents
+ if len(events) <= load_limit:
+ limited = False
+ max_repeat -= 1
+
+ if len(recents) > sync_config.limit:
+ recents = recents[-sync_config.limit:]
+ room_key = recents[0].internal_metadata.before
+
+ prev_batch_token = now_token.copy_and_replace(
+ "room_key", room_key
+ )
+
+ defer.returnValue((recents, prev_batch_token, limited))
+
+ @defer.inlineCallbacks
+ def incremental_sync_with_gap_for_room(self, room_id, sync_config,
+ since_token, now_token,
+ published_room_ids, typing_by_room):
+ """ Get the incremental delta needed to bring the client up to date for
+ the room. Gives the client the most recent events and the changes to
+ state.
+ Returns:
+ A Deferred RoomSyncResult
+ """
+
+ # TODO(mjark): Check for redactions we might have missed.
+
+ recents, prev_batch_token, limited = yield self.load_filtered_recents(
+ room_id, sync_config, now_token, since_token,
+ )
+
+ logging.debug("Recents %r", recents)
+
+ # TODO(mjark): This seems racy since this isn't being passed a
+ # token to indicate what point in the stream this is
+ current_state_events = yield self.state_handler.get_current_state(
+ room_id
+ )
+
+ state_at_previous_sync = yield self.get_state_at_previous_sync(
+ room_id, since_token=since_token
+ )
+
+ state_events_delta = yield self.compute_state_delta(
+ since_token=since_token,
+ previous_state=state_at_previous_sync,
+ current_state=current_state_events,
+ )
+
+ state_events_delta = yield self.check_joined_room(
+ sync_config, room_id, state_events_delta
+ )
+
+ room_sync = RoomSyncResult(
+ room_id=room_id,
+ published=room_id in published_room_ids,
+ events=recents,
+ prev_batch=prev_batch_token,
+ state=state_events_delta,
+ limited=limited,
+ ephemeral=typing_by_room.get(room_id, [])
+ )
+
+ logging.debug("Room sync: %r", room_sync)
+
+ defer.returnValue(room_sync)
+
+ @defer.inlineCallbacks
+ def get_state_at_previous_sync(self, room_id, since_token):
+ """ Get the room state at the previous sync the client made.
+ Returns:
+ A Deferred list of Events.
+ """
+ last_events, token = yield self.store.get_recent_events_for_room(
+ room_id, end_token=since_token.room_key, limit=1,
+ )
+
+ if last_events:
+ last_event = last_events[0]
+ last_context = yield self.state_handler.compute_event_context(
+ last_event
+ )
+ if last_event.is_state():
+ state = [last_event] + last_context.current_state.values()
+ else:
+ state = last_context.current_state.values()
+ else:
+ state = ()
+ defer.returnValue(state)
+
+ def compute_state_delta(self, since_token, previous_state, current_state):
+ """ Works out the differnce in state between the current state and the
+ state the client got when it last performed a sync.
+ Returns:
+ A list of events.
+ """
+ # TODO(mjark) Check if the state events were received by the server
+ # after the previous sync, since we need to include those state
+ # updates even if they occured logically before the previous event.
+ # TODO(mjark) Check for new redactions in the state events.
+ previous_dict = {event.event_id: event for event in previous_state}
+ state_delta = []
+ for event in current_state:
+ if event.event_id not in previous_dict:
+ state_delta.append(event)
+ return state_delta
+
+ @defer.inlineCallbacks
+ def check_joined_room(self, sync_config, room_id, state_delta):
+ joined = False
+ for event in state_delta:
+ if (
+ event.type == EventTypes.Member
+ and event.state_key == sync_config.user.to_string()
+ ):
+ if event.content["membership"] == Membership.JOIN:
+ joined = True
+
+ if joined:
+ state_delta = yield self.state_handler.get_current_state(room_id)
+
+ defer.returnValue(state_delta)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 7793bab106..198f575cfa 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -63,6 +63,25 @@ class SimpleHttpClient(object):
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
+ def post_json_get_json(self, uri, post_json):
+ json_str = json.dumps(post_json)
+
+ logger.info("HTTP POST %s -> %s", json_str, uri)
+
+ response = yield self.agent.request(
+ "POST",
+ uri.encode("ascii"),
+ headers=Headers({
+ "Content-Type": ["application/json"]
+ }),
+ bodyProducer=FileBodyProducer(StringIO(json_str))
+ )
+
+ body = yield readBody(response)
+
+ defer.returnValue(json.loads(body))
+
+ @defer.inlineCallbacks
def get_json(self, uri, args={}):
""" Get's some json from the given host and path
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 1dda3ba2c7..c7bf1b47b8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -245,6 +245,43 @@ class MatrixFederationHttpClient(object):
defer.returnValue((response.code, body))
@defer.inlineCallbacks
+ def post_json(self, destination, path, data={}):
+ """ Sends the specifed json data using POST
+
+ Args:
+ destination (str): The remote server to send the HTTP request
+ to.
+ path (str): The HTTP path.
+ data (dict): A dict containing the data that will be used as
+ the request body. This will be encoded as JSON.
+
+ Returns:
+ Deferred: Succeeds when we get a 2xx HTTP response. The result
+ will be the decoded JSON body. On a 4xx or 5xx error response a
+ CodeMessageException is raised.
+ """
+
+ def body_callback(method, url_bytes, headers_dict):
+ self.sign_request(
+ destination, method, url_bytes, headers_dict, data
+ )
+ return _JsonProducer(data)
+
+ response = yield self._create_request(
+ destination.encode("ascii"),
+ "POST",
+ path.encode("ascii"),
+ body_callback=body_callback,
+ headers_dict={"Content-Type": ["application/json"]},
+ )
+
+ logger.debug("Getting resp body")
+ body = yield readBody(response)
+ logger.debug("Got resp body")
+
+ defer.returnValue((response.code, body))
+
+ @defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
""" GETs some json from the given host homeserver and path
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 8015a22edf..0f6539e1be 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -16,7 +16,7 @@
from synapse.http.agent_name import AGENT_NAME
from synapse.api.errors import (
- cs_exception, SynapseError, CodeMessageException
+ cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
)
from synapse.util.logcontext import LoggingContext
@@ -139,11 +139,7 @@ class JsonResource(HttpServer, resource.Resource):
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
- self._send_response(
- request,
- 400,
- {"error": "Unrecognized request"}
- )
+ raise UnrecognizedRequestError()
except CodeMessageException as e:
if isinstance(e, SynapseError):
logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 3aec1d4af2..e3b6ead620 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.async import run_on_reactor
+from synapse.types import StreamToken
import logging
@@ -205,6 +206,53 @@ class Notifier(object):
[notify(l).addErrback(eb) for l in listeners]
)
+ @defer.inlineCallbacks
+ def wait_for_events(self, user, rooms, filter, timeout, callback):
+ """Wait until the callback returns a non empty response or the
+ timeout fires.
+ """
+
+ deferred = defer.Deferred()
+
+ from_token = StreamToken("s0", "0", "0")
+
+ listener = [_NotificationListener(
+ user=user,
+ rooms=rooms,
+ from_token=from_token,
+ limit=1,
+ timeout=timeout,
+ deferred=deferred,
+ )]
+
+ if timeout:
+ self._register_with_keys(listener[0])
+
+ result = yield callback()
+ if timeout:
+ timed_out = [False]
+
+ def _timeout_listener():
+ timed_out[0] = True
+ listener[0].notify(self, [], from_token, from_token)
+
+ self.clock.call_later(timeout/1000., _timeout_listener)
+ while not result and not timed_out[0]:
+ yield deferred
+ deferred = defer.Deferred()
+ listener[0] = _NotificationListener(
+ user=user,
+ rooms=rooms,
+ from_token=from_token,
+ limit=1,
+ timeout=timeout,
+ deferred=deferred,
+ )
+ self._register_with_keys(listener[0])
+ result = yield callback()
+
+ defer.returnValue(result)
+
def get_events_for(self, user, rooms, pagination_config, timeout):
""" For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
new file mode 100644
index 0000000000..28e5dae81d
--- /dev/null
+++ b/synapse/push/__init__.py
@@ -0,0 +1,410 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.streams.config import PaginationConfig
+from synapse.types import StreamToken, UserID
+
+import synapse.util.async
+import baserules
+
+import logging
+import fnmatch
+import json
+import re
+
+logger = logging.getLogger(__name__)
+
+
+class Pusher(object):
+ INITIAL_BACKOFF = 1000
+ MAX_BACKOFF = 60 * 60 * 1000
+ GIVE_UP_AFTER = 24 * 60 * 60 * 1000
+ DEFAULT_ACTIONS = ['notify']
+
+ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
+
+ def __init__(self, _hs, instance_handle, user_name, app_id,
+ app_display_name, device_display_name, pushkey, pushkey_ts,
+ data, last_token, last_success, failing_since):
+ self.hs = _hs
+ self.evStreamHandler = self.hs.get_handlers().event_stream_handler
+ self.store = self.hs.get_datastore()
+ self.clock = self.hs.get_clock()
+ self.instance_handle = instance_handle
+ self.user_name = user_name
+ self.app_id = app_id
+ self.app_display_name = app_display_name
+ self.device_display_name = device_display_name
+ self.pushkey = pushkey
+ self.pushkey_ts = pushkey_ts
+ self.data = data
+ self.last_token = last_token
+ self.last_success = last_success # not actually used
+ self.backoff_delay = Pusher.INITIAL_BACKOFF
+ self.failing_since = failing_since
+ self.alive = True
+
+ # The last value of last_active_time that we saw
+ self.last_last_active_time = 0
+ self.has_unread = True
+
+ @defer.inlineCallbacks
+ def _actions_for_event(self, ev):
+ """
+ This should take into account notification settings that the user
+ has configured both globally and per-room when we have the ability
+ to do such things.
+ """
+ if ev['user_id'] == self.user_name:
+ # let's assume you probably know about messages you sent yourself
+ defer.returnValue(['dont_notify'])
+
+ if ev['type'] == 'm.room.member':
+ if ev['state_key'] != self.user_name:
+ defer.returnValue(['dont_notify'])
+
+ rules = yield self.store.get_push_rules_for_user_name(self.user_name)
+
+ for r in rules:
+ r['conditions'] = json.loads(r['conditions'])
+ r['actions'] = json.loads(r['actions'])
+
+ user_name_localpart = UserID.from_string(self.user_name).localpart
+
+ rules.extend(baserules.make_base_rules(user_name_localpart))
+
+ # get *our* member event for display name matching
+ member_events_for_room = yield self.store.get_current_state(
+ room_id=ev['room_id'],
+ event_type='m.room.member',
+ state_key=None
+ )
+ my_display_name = None
+ room_member_count = 0
+ for mev in member_events_for_room:
+ if mev.content['membership'] != 'join':
+ continue
+
+ # This loop does two things:
+ # 1) Find our current display name
+ if mev.state_key == self.user_name and 'displayname' in mev.content:
+ my_display_name = mev.content['displayname']
+
+ # and 2) Get the number of people in that room
+ room_member_count += 1
+
+ for r in rules:
+ matches = True
+
+ conditions = r['conditions']
+ actions = r['actions']
+
+ for c in conditions:
+ matches &= self._event_fulfills_condition(
+ ev, c, display_name=my_display_name,
+ room_member_count=room_member_count
+ )
+ # ignore rules with no actions (we have an explict 'dont_notify'
+ if len(actions) == 0:
+ logger.warn(
+ "Ignoring rule id %s with no actions for user %s" %
+ (r['rule_id'], r['user_name'])
+ )
+ continue
+ if matches:
+ defer.returnValue(actions)
+
+ defer.returnValue(Pusher.DEFAULT_ACTIONS)
+
+ def _event_fulfills_condition(self, ev, condition, display_name, room_member_count):
+ if condition['kind'] == 'event_match':
+ if 'pattern' not in condition:
+ logger.warn("event_match condition with no pattern")
+ return False
+ pat = condition['pattern']
+
+ if pat.strip("*?[]") == pat:
+ # no special glob characters so we assume the user means
+ # 'contains this string' rather than 'is this string'
+ pat = "*%s*" % (pat,)
+
+ val = _value_for_dotted_key(condition['key'], ev)
+ if val is None:
+ return False
+ return fnmatch.fnmatch(val.upper(), pat.upper())
+ elif condition['kind'] == 'device':
+ if 'instance_handle' not in condition:
+ return True
+ return condition['instance_handle'] == self.instance_handle
+ elif condition['kind'] == 'contains_display_name':
+ # This is special because display names can be different
+ # between rooms and so you can't really hard code it in a rule.
+ # Optimisation: we should cache these names and update them from
+ # the event stream.
+ if 'content' not in ev or 'body' not in ev['content']:
+ return False
+ if not display_name:
+ return False
+ return fnmatch.fnmatch(
+ ev['content']['body'].upper(), "*%s*" % (display_name.upper(),)
+ )
+ elif condition['kind'] == 'room_member_count':
+ if 'is' not in condition:
+ return False
+ m = Pusher.INEQUALITY_EXPR.match(condition['is'])
+ if not m:
+ return False
+ ineq = m.group(1)
+ rhs = m.group(2)
+ if not rhs.isdigit():
+ return False
+ rhs = int(rhs)
+
+ if ineq == '' or ineq == '==':
+ return room_member_count == rhs
+ elif ineq == '<':
+ return room_member_count < rhs
+ elif ineq == '>':
+ return room_member_count > rhs
+ elif ineq == '>=':
+ return room_member_count >= rhs
+ elif ineq == '<=':
+ return room_member_count <= rhs
+ else:
+ return False
+ else:
+ return True
+
+ @defer.inlineCallbacks
+ def get_context_for_event(self, ev):
+ name_aliases = yield self.store.get_room_name_and_aliases(
+ ev['room_id']
+ )
+
+ ctx = {'aliases': name_aliases[1]}
+ if name_aliases[0] is not None:
+ ctx['name'] = name_aliases[0]
+
+ their_member_events_for_room = yield self.store.get_current_state(
+ room_id=ev['room_id'],
+ event_type='m.room.member',
+ state_key=ev['user_id']
+ )
+ for mev in their_member_events_for_room:
+ if mev.content['membership'] == 'join' and 'displayname' in mev.content:
+ dn = mev.content['displayname']
+ if dn is not None:
+ ctx['sender_display_name'] = dn
+
+ defer.returnValue(ctx)
+
+ @defer.inlineCallbacks
+ def start(self):
+ if not self.last_token:
+ # First-time setup: get a token to start from (we can't
+ # just start from no token, ie. 'now'
+ # because we need the result to be reproduceable in case
+ # we fail to dispatch the push)
+ config = PaginationConfig(from_token=None, limit='1')
+ chunk = yield self.evStreamHandler.get_stream(
+ self.user_name, config, timeout=0)
+ self.last_token = chunk['end']
+ self.store.update_pusher_last_token(
+ self.user_name, self.pushkey, self.last_token)
+ logger.info("Pusher %s for user %s starting from token %s",
+ self.pushkey, self.user_name, self.last_token)
+
+ while self.alive:
+ from_tok = StreamToken.from_string(self.last_token)
+ config = PaginationConfig(from_token=from_tok, limit='1')
+ chunk = yield self.evStreamHandler.get_stream(
+ self.user_name, config,
+ timeout=100*365*24*60*60*1000, affect_presence=False
+ )
+
+ # limiting to 1 may get 1 event plus 1 presence event, so
+ # pick out the actual event
+ single_event = None
+ for c in chunk['chunk']:
+ if 'event_id' in c: # Hmmm...
+ single_event = c
+ break
+ if not single_event:
+ self.last_token = chunk['end']
+ continue
+
+ if not self.alive:
+ continue
+
+ processed = False
+ actions = yield self._actions_for_event(single_event)
+ tweaks = _tweaks_for_actions(actions)
+
+ if len(actions) == 0:
+ logger.warn("Empty actions! Using default action.")
+ actions = Pusher.DEFAULT_ACTIONS
+ if 'notify' not in actions and 'dont_notify' not in actions:
+ logger.warn("Neither notify nor dont_notify in actions: adding default")
+ actions.extend(Pusher.DEFAULT_ACTIONS)
+ if 'dont_notify' in actions:
+ logger.debug(
+ "%s for %s: dont_notify",
+ single_event['event_id'], self.user_name
+ )
+ processed = True
+ else:
+ rejected = yield self.dispatch_push(single_event, tweaks)
+ self.has_unread = True
+ if isinstance(rejected, list) or isinstance(rejected, tuple):
+ processed = True
+ for pk in rejected:
+ if pk != self.pushkey:
+ # for sanity, we only remove the pushkey if it
+ # was the one we actually sent...
+ logger.warn(
+ ("Ignoring rejected pushkey %s because we"
+ " didn't send it"), pk
+ )
+ else:
+ logger.info(
+ "Pushkey %s was rejected: removing",
+ pk
+ )
+ yield self.hs.get_pusherpool().remove_pusher(
+ self.app_id, pk
+ )
+
+ if not self.alive:
+ continue
+
+ if processed:
+ self.backoff_delay = Pusher.INITIAL_BACKOFF
+ self.last_token = chunk['end']
+ self.store.update_pusher_last_token_and_success(
+ self.user_name,
+ self.pushkey,
+ self.last_token,
+ self.clock.time_msec()
+ )
+ if self.failing_since:
+ self.failing_since = None
+ self.store.update_pusher_failing_since(
+ self.user_name,
+ self.pushkey,
+ self.failing_since)
+ else:
+ if not self.failing_since:
+ self.failing_since = self.clock.time_msec()
+ self.store.update_pusher_failing_since(
+ self.user_name,
+ self.pushkey,
+ self.failing_since
+ )
+
+ if (self.failing_since and
+ self.failing_since <
+ self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
+ # we really only give up so that if the URL gets
+ # fixed, we don't suddenly deliver a load
+ # of old notifications.
+ logger.warn("Giving up on a notification to user %s, "
+ "pushkey %s",
+ self.user_name, self.pushkey)
+ self.backoff_delay = Pusher.INITIAL_BACKOFF
+ self.last_token = chunk['end']
+ self.store.update_pusher_last_token(
+ self.user_name,
+ self.pushkey,
+ self.last_token
+ )
+
+ self.failing_since = None
+ self.store.update_pusher_failing_since(
+ self.user_name,
+ self.pushkey,
+ self.failing_since
+ )
+ else:
+ logger.warn("Failed to dispatch push for user %s "
+ "(failing for %dms)."
+ "Trying again in %dms",
+ self.user_name,
+ self.clock.time_msec() - self.failing_since,
+ self.backoff_delay)
+ yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
+ self.backoff_delay *= 2
+ if self.backoff_delay > Pusher.MAX_BACKOFF:
+ self.backoff_delay = Pusher.MAX_BACKOFF
+
+ def stop(self):
+ self.alive = False
+
+ def dispatch_push(self, p, tweaks):
+ """
+ Overridden by implementing classes to actually deliver the notification
+ Args:
+ p: The event to notify for as a single event from the event stream
+ Returns: If the notification was delivered, an array containing any
+ pushkeys that were rejected by the push gateway.
+ False if the notification could not be delivered (ie.
+ should be retried).
+ """
+ pass
+
+ def reset_badge_count(self):
+ pass
+
+ def presence_changed(self, state):
+ """
+ We clear badge counts whenever a user's last_active time is bumped
+ This is by no means perfect but I think it's the best we can do
+ without read receipts.
+ """
+ if 'last_active' in state.state:
+ last_active = state.state['last_active']
+ if last_active > self.last_last_active_time:
+ self.last_last_active_time = last_active
+ if self.has_unread:
+ logger.info("Resetting badge count for %s", self.user_name)
+ self.reset_badge_count()
+ self.has_unread = False
+
+
+def _value_for_dotted_key(dotted_key, event):
+ parts = dotted_key.split(".")
+ val = event
+ while len(parts) > 0:
+ if parts[0] not in val:
+ return None
+ val = val[parts[0]]
+ parts = parts[1:]
+ return val
+
+
+def _tweaks_for_actions(actions):
+ tweaks = {}
+ for a in actions:
+ if not isinstance(a, dict):
+ continue
+ if 'set_sound' in a:
+ tweaks['sound'] = a['set_sound']
+ return tweaks
+
+
+class PusherConfigException(Exception):
+ def __init__(self, msg):
+ super(PusherConfigException, self).__init__(msg)
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
new file mode 100644
index 0000000000..382de118e0
--- /dev/null
+++ b/synapse/push/baserules.py
@@ -0,0 +1,48 @@
+def make_base_rules(user_name):
+ rules = [
+ {
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'content.body',
+ 'pattern': '*%s*' % (user_name,), # Matrix ID match
+ }
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_sound': 'default'
+ }
+ ]
+ },
+ {
+ 'conditions': [
+ {
+ 'kind': 'contains_display_name'
+ }
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_sound': 'default'
+ }
+ ]
+ },
+ {
+ 'conditions': [
+ {
+ 'kind': 'room_member_count',
+ 'is': '2'
+ }
+ ],
+ 'actions': [
+ 'notify',
+ {
+ 'set_sound': 'default'
+ }
+ ]
+ }
+ ]
+ for r in rules:
+ r['priority_class'] = 0
+ return rules
\ No newline at end of file
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
new file mode 100644
index 0000000000..7c6953c989
--- /dev/null
+++ b/synapse/push/httppusher.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.push import Pusher, PusherConfigException
+from synapse.http.client import SimpleHttpClient
+
+from twisted.internet import defer
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class HttpPusher(Pusher):
+ def __init__(self, _hs, instance_handle, user_name, app_id,
+ app_display_name, device_display_name, pushkey, pushkey_ts,
+ data, last_token, last_success, failing_since):
+ super(HttpPusher, self).__init__(
+ _hs,
+ instance_handle,
+ user_name,
+ app_id,
+ app_display_name,
+ device_display_name,
+ pushkey,
+ pushkey_ts,
+ data,
+ last_token,
+ last_success,
+ failing_since
+ )
+ if 'url' not in data:
+ raise PusherConfigException(
+ "'url' required in data for HTTP pusher"
+ )
+ self.url = data['url']
+ self.httpCli = SimpleHttpClient(self.hs)
+ self.data_minus_url = {}
+ self.data_minus_url.update(self.data)
+ del self.data_minus_url['url']
+
+ @defer.inlineCallbacks
+ def _build_notification_dict(self, event, tweaks):
+ # we probably do not want to push for every presence update
+ # (we may want to be able to set up notifications when specific
+ # people sign in, but we'd want to only deliver the pertinent ones)
+ # Actually, presence events will not get this far now because we
+ # need to filter them out in the main Pusher code.
+ if 'event_id' not in event:
+ defer.returnValue(None)
+
+ ctx = yield self.get_context_for_event(event)
+
+ d = {
+ 'notification': {
+ 'id': event['event_id'],
+ 'type': event['type'],
+ 'sender': event['user_id'],
+ 'counts': { # -- we don't mark messages as read yet so
+ # we have no way of knowing
+ # Just set the badge to 1 until we have read receipts
+ 'unread': 1,
+ # 'missed_calls': 2
+ },
+ 'devices': [
+ {
+ 'app_id': self.app_id,
+ 'pushkey': self.pushkey,
+ 'pushkey_ts': long(self.pushkey_ts / 1000),
+ 'data': self.data_minus_url,
+ 'tweaks': tweaks
+ }
+ ]
+ }
+ }
+ if event['type'] == 'm.room.member':
+ d['notification']['membership'] = event['content']['membership']
+ if 'content' in event:
+ d['notification']['content'] = event['content']
+
+ if len(ctx['aliases']):
+ d['notification']['room_alias'] = ctx['aliases'][0]
+ if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
+ d['notification']['sender_display_name'] = ctx['sender_display_name']
+ if 'name' in ctx and len(ctx['name']) > 0:
+ d['notification']['room_name'] = ctx['name']
+
+ defer.returnValue(d)
+
+ @defer.inlineCallbacks
+ def dispatch_push(self, event, tweaks):
+ notification_dict = yield self._build_notification_dict(event, tweaks)
+ if not notification_dict:
+ defer.returnValue([])
+ try:
+ resp = yield self.httpCli.post_json_get_json(self.url, notification_dict)
+ except:
+ logger.exception("Failed to push %s ", self.url)
+ defer.returnValue(False)
+ rejected = []
+ if 'rejected' in resp:
+ rejected = resp['rejected']
+ defer.returnValue(rejected)
+
+ @defer.inlineCallbacks
+ def reset_badge_count(self):
+ d = {
+ 'notification': {
+ 'id': '',
+ 'type': None,
+ 'sender': '',
+ 'counts': {
+ 'unread': 0,
+ 'missed_calls': 0
+ },
+ 'devices': [
+ {
+ 'app_id': self.app_id,
+ 'pushkey': self.pushkey,
+ 'pushkey_ts': long(self.pushkey_ts / 1000),
+ 'data': self.data_minus_url,
+ }
+ ]
+ }
+ }
+ try:
+ resp = yield self.httpCli.post_json_get_json(self.url, d)
+ except:
+ logger.exception("Failed to push %s ", self.url)
+ defer.returnValue(False)
+ rejected = []
+ if 'rejected' in resp:
+ rejected = resp['rejected']
+ defer.returnValue(rejected)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
new file mode 100644
index 0000000000..4892c21e7b
--- /dev/null
+++ b/synapse/push/pusherpool.py
@@ -0,0 +1,152 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from httppusher import HttpPusher
+from synapse.push import PusherConfigException
+
+import logging
+import json
+
+logger = logging.getLogger(__name__)
+
+
+class PusherPool:
+ def __init__(self, _hs):
+ self.hs = _hs
+ self.store = self.hs.get_datastore()
+ self.pushers = {}
+ self.last_pusher_started = -1
+
+ distributor = self.hs.get_distributor()
+ distributor.observe(
+ "user_presence_changed", self.user_presence_changed
+ )
+
+ @defer.inlineCallbacks
+ def user_presence_changed(self, user, state):
+ user_name = user.to_string()
+
+ # until we have read receipts, pushers use this to reset a user's
+ # badge counters to zero
+ for p in self.pushers.values():
+ if p.user_name == user_name:
+ yield p.presence_changed(state)
+
+ @defer.inlineCallbacks
+ def start(self):
+ pushers = yield self.store.get_all_pushers()
+ for p in pushers:
+ p['data'] = json.loads(p['data'])
+ self._start_pushers(pushers)
+
+ @defer.inlineCallbacks
+ def add_pusher(self, user_name, instance_handle, kind, app_id,
+ app_display_name, device_display_name, pushkey, lang, data):
+ # we try to create the pusher just to validate the config: it
+ # will then get pulled out of the database,
+ # recreated, added and started: this means we have only one
+ # code path adding pushers.
+ self._create_pusher({
+ "user_name": user_name,
+ "kind": kind,
+ "instance_handle": instance_handle,
+ "app_id": app_id,
+ "app_display_name": app_display_name,
+ "device_display_name": device_display_name,
+ "pushkey": pushkey,
+ "pushkey_ts": self.hs.get_clock().time_msec(),
+ "lang": lang,
+ "data": data,
+ "last_token": None,
+ "last_success": None,
+ "failing_since": None
+ })
+ yield self._add_pusher_to_store(
+ user_name, instance_handle, kind, app_id,
+ app_display_name, device_display_name,
+ pushkey, lang, data
+ )
+
+ @defer.inlineCallbacks
+ def _add_pusher_to_store(self, user_name, instance_handle, kind, app_id,
+ app_display_name, device_display_name,
+ pushkey, lang, data):
+ yield self.store.add_pusher(
+ user_name=user_name,
+ instance_handle=instance_handle,
+ kind=kind,
+ app_id=app_id,
+ app_display_name=app_display_name,
+ device_display_name=device_display_name,
+ pushkey=pushkey,
+ pushkey_ts=self.hs.get_clock().time_msec(),
+ lang=lang,
+ data=json.dumps(data)
+ )
+ self._refresh_pusher((app_id, pushkey))
+
+ def _create_pusher(self, pusherdict):
+ if pusherdict['kind'] == 'http':
+ return HttpPusher(
+ self.hs,
+ instance_handle=pusherdict['instance_handle'],
+ user_name=pusherdict['user_name'],
+ app_id=pusherdict['app_id'],
+ app_display_name=pusherdict['app_display_name'],
+ device_display_name=pusherdict['device_display_name'],
+ pushkey=pusherdict['pushkey'],
+ pushkey_ts=pusherdict['pushkey_ts'],
+ data=pusherdict['data'],
+ last_token=pusherdict['last_token'],
+ last_success=pusherdict['last_success'],
+ failing_since=pusherdict['failing_since']
+ )
+ else:
+ raise PusherConfigException(
+ "Unknown pusher type '%s' for user %s" %
+ (pusherdict['kind'], pusherdict['user_name'])
+ )
+
+ @defer.inlineCallbacks
+ def _refresh_pusher(self, app_id_pushkey):
+ p = yield self.store.get_pushers_by_app_id_and_pushkey(
+ app_id_pushkey
+ )
+ p['data'] = json.loads(p['data'])
+
+ self._start_pushers([p])
+
+ def _start_pushers(self, pushers):
+ logger.info("Starting %d pushers", len(pushers))
+ for pusherdict in pushers:
+ p = self._create_pusher(pusherdict)
+ if p:
+ fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey'])
+ if fullid in self.pushers:
+ self.pushers[fullid].stop()
+ self.pushers[fullid] = p
+ p.start()
+
+ @defer.inlineCallbacks
+ def remove_pusher(self, app_id, pushkey):
+ fullid = "%s:%s" % (app_id, pushkey)
+ if fullid in self.pushers:
+ logger.info("Stopping pusher %s", fullid)
+ self.pushers[fullid].stop()
+ del self.pushers[fullid]
+ yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 4182ad990f..826a36f203 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -6,7 +6,7 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = {
"syutil==0.0.2": ["syutil"],
"matrix_angular_sdk==0.6.0": ["syweb>=0.6.0"],
- "Twisted>=14.0.0": ["twisted>=14.0.0"],
+ "Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"],
diff --git a/synapse/rest/client/v1/__init__.py b/synapse/rest/client/v1/__init__.py
index 8bb89b2f6a..d8d01cdd16 100644
--- a/synapse/rest/client/v1/__init__.py
+++ b/synapse/rest/client/v1/__init__.py
@@ -13,10 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
from . import (
room, events, register, login, profile, presence, initial_sync, directory,
- voip, admin,
+ voip, admin, pusher, push_rule
)
from synapse.http.server import JsonResource
@@ -41,3 +40,5 @@ class ClientV1RestResource(JsonResource):
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
admin.register_servlets(hs, client_resource)
+ pusher.register_servlets(hs, client_resource)
+ push_rule.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 1051d96f96..2ce754b028 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user:
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 15ae8749b8..8f65efec5f 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -45,7 +45,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
- user = yield self.auth.get_user_by_req(request)
+ user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
if not "room_id" in content:
@@ -85,7 +85,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request, room_alias):
- user = yield self.auth.get_user_by_req(request)
+ user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index a0d051227b..77b7c25a03 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
try:
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
@@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, event_id):
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id)
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 357fa845b4..4a259bba64 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- user = yield self.auth.get_user_by_req(request)
+ user, client = yield self.auth.get_user_by_req(request)
with_feedback = "feedback" in request.args
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index b6c207e662..7feb4aadb1 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = yield self.handlers.presence_handler.get_state(
@@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
state = {}
@@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
@@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request, user_id):
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if not self.hs.is_mine(user):
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 24f8d56952..15d6f3fc6c 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
try:
@@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
try:
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
new file mode 100644
index 0000000000..faa7919fbb
--- /dev/null
+++ b/synapse/rest/client/v1/push_rule.py
@@ -0,0 +1,406 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError, NotFoundError, \
+ StoreError
+from .base import ClientV1RestServlet, client_path_pattern
+from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+import synapse.push.baserules as baserules
+
+import json
+
+
+class PushRuleRestServlet(ClientV1RestServlet):
+ PATTERN = client_path_pattern("/pushrules/.*$")
+ PRIORITY_CLASS_MAP = {
+ 'default': 0,
+ 'underride': 1,
+ 'sender': 2,
+ 'room': 3,
+ 'content': 4,
+ 'override': 5,
+ }
+ PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()}
+ SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
+ "Unrecognised request: You probably wanted a trailing slash")
+
+ def rule_spec_from_path(self, path):
+ if len(path) < 2:
+ raise UnrecognizedRequestError()
+ if path[0] != 'pushrules':
+ raise UnrecognizedRequestError()
+
+ scope = path[1]
+ path = path[2:]
+ if scope not in ['global', 'device']:
+ raise UnrecognizedRequestError()
+
+ device = None
+ if scope == 'device':
+ if len(path) == 0:
+ raise UnrecognizedRequestError()
+ device = path[0]
+ path = path[1:]
+
+ if len(path) == 0:
+ raise UnrecognizedRequestError()
+
+ template = path[0]
+ path = path[1:]
+
+ if len(path) == 0:
+ raise UnrecognizedRequestError()
+
+ rule_id = path[0]
+
+ spec = {
+ 'scope': scope,
+ 'template': template,
+ 'rule_id': rule_id
+ }
+ if device:
+ spec['device'] = device
+ return spec
+
+ def rule_tuple_from_request_object(self, rule_template, rule_id, req_obj, device=None):
+ if rule_template in ['override', 'underride']:
+ if 'conditions' not in req_obj:
+ raise InvalidRuleException("Missing 'conditions'")
+ conditions = req_obj['conditions']
+ for c in conditions:
+ if 'kind' not in c:
+ raise InvalidRuleException("Condition without 'kind'")
+ elif rule_template == 'room':
+ conditions = [{
+ 'kind': 'event_match',
+ 'key': 'room_id',
+ 'pattern': rule_id
+ }]
+ elif rule_template == 'sender':
+ conditions = [{
+ 'kind': 'event_match',
+ 'key': 'user_id',
+ 'pattern': rule_id
+ }]
+ elif rule_template == 'content':
+ if 'pattern' not in req_obj:
+ raise InvalidRuleException("Content rule missing 'pattern'")
+ pat = req_obj['pattern']
+
+ conditions = [{
+ 'kind': 'event_match',
+ 'key': 'content.body',
+ 'pattern': pat
+ }]
+ else:
+ raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
+
+ if device:
+ conditions.append({
+ 'kind': 'device',
+ 'instance_handle': device
+ })
+
+ if 'actions' not in req_obj:
+ raise InvalidRuleException("No actions found")
+ actions = req_obj['actions']
+
+ for a in actions:
+ if a in ['notify', 'dont_notify', 'coalesce']:
+ pass
+ elif isinstance(a, dict) and 'set_sound' in a:
+ pass
+ else:
+ raise InvalidRuleException("Unrecognised action")
+
+ return conditions, actions
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request):
+ spec = self.rule_spec_from_path(request.postpath)
+ try:
+ priority_class = _priority_class_from_spec(spec)
+ except InvalidRuleException as e:
+ raise SynapseError(400, e.message)
+
+ user, _ = yield self.auth.get_user_by_req(request)
+
+ if spec['template'] == 'default':
+ raise SynapseError(403, "The default rules are immutable.")
+
+ content = _parse_json(request)
+
+ try:
+ (conditions, actions) = self.rule_tuple_from_request_object(
+ spec['template'],
+ spec['rule_id'],
+ content,
+ device=spec['device'] if 'device' in spec else None
+ )
+ except InvalidRuleException as e:
+ raise SynapseError(400, e.message)
+
+ before = request.args.get("before", None)
+ if before and len(before):
+ before = before[0]
+ after = request.args.get("after", None)
+ if after and len(after):
+ after = after[0]
+
+ try:
+ yield self.hs.get_datastore().add_push_rule(
+ user_name=user.to_string(),
+ rule_id=spec['rule_id'],
+ priority_class=priority_class,
+ conditions=conditions,
+ actions=actions,
+ before=before,
+ after=after
+ )
+ except InconsistentRuleException as e:
+ raise SynapseError(400, e.message)
+ except RuleNotFoundException as e:
+ raise SynapseError(400, e.message)
+
+ defer.returnValue((200, {}))
+
+ @defer.inlineCallbacks
+ def on_DELETE(self, request):
+ spec = self.rule_spec_from_path(request.postpath)
+ try:
+ priority_class = _priority_class_from_spec(spec)
+ except InvalidRuleException as e:
+ raise SynapseError(400, e.message)
+
+ user, _ = yield self.auth.get_user_by_req(request)
+
+ if 'device' in spec:
+ rules = yield self.hs.get_datastore().get_push_rules_for_user_name(
+ user.to_string()
+ )
+
+ for r in rules:
+ conditions = json.loads(r['conditions'])
+ ih = _instance_handle_from_conditions(conditions)
+ if ih == spec['device'] and r['priority_class'] == priority_class:
+ yield self.hs.get_datastore().delete_push_rule(
+ user.to_string(), spec['rule_id']
+ )
+ defer.returnValue((200, {}))
+ raise NotFoundError()
+ else:
+ try:
+ yield self.hs.get_datastore().delete_push_rule(
+ user.to_string(), spec['rule_id'],
+ priority_class=priority_class
+ )
+ defer.returnValue((200, {}))
+ except StoreError as e:
+ if e.code == 404:
+ raise NotFoundError()
+ else:
+ raise
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ user, _ = yield self.auth.get_user_by_req(request)
+
+ # we build up the full structure and then decide which bits of it
+ # to send which means doing unnecessary work sometimes but is
+ # is probably not going to make a whole lot of difference
+ rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name(user.to_string())
+ for r in rawrules:
+ r["conditions"] = json.loads(r["conditions"])
+ r["actions"] = json.loads(r["actions"])
+ rawrules.extend(baserules.make_base_rules(user.to_string()))
+
+ rules = {'global': {}, 'device': {}}
+
+ rules['global'] = _add_empty_priority_class_arrays(rules['global'])
+
+ for r in rawrules:
+ rulearray = None
+
+ template_name = _priority_class_to_template_name(r['priority_class'])
+
+ if r['priority_class'] > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']:
+ # per-device rule
+ instance_handle = _instance_handle_from_conditions(r["conditions"])
+ r = _strip_device_condition(r)
+ if not instance_handle:
+ continue
+ if instance_handle not in rules['device']:
+ rules['device'][instance_handle] = {}
+ rules['device'][instance_handle] = (
+ _add_empty_priority_class_arrays(
+ rules['device'][instance_handle]
+ )
+ )
+
+ rulearray = rules['device'][instance_handle][template_name]
+ else:
+ rulearray = rules['global'][template_name]
+
+ template_rule = _rule_to_template(r)
+ if template_rule:
+ rulearray.append(template_rule)
+
+ path = request.postpath[1:]
+
+ if path == []:
+ # we're a reference impl: pedantry is our job.
+ raise UnrecognizedRequestError(
+ PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
+ )
+
+ if path[0] == '':
+ defer.returnValue((200, rules))
+ elif path[0] == 'global':
+ path = path[1:]
+ result = _filter_ruleset_with_path(rules['global'], path)
+ defer.returnValue((200, result))
+ elif path[0] == 'device':
+ path = path[1:]
+ if path == []:
+ raise UnrecognizedRequestError(
+ PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
+ )
+ if path[0] == '':
+ defer.returnValue((200, rules['device']))
+
+ instance_handle = path[0]
+ path = path[1:]
+ if instance_handle not in rules['device']:
+ ret = {}
+ ret = _add_empty_priority_class_arrays(ret)
+ defer.returnValue((200, ret))
+ ruleset = rules['device'][instance_handle]
+ result = _filter_ruleset_with_path(ruleset, path)
+ defer.returnValue((200, result))
+ else:
+ raise UnrecognizedRequestError()
+
+ def on_OPTIONS(self, _):
+ return 200, {}
+
+
+def _add_empty_priority_class_arrays(d):
+ for pc in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys():
+ d[pc] = []
+ return d
+
+
+def _instance_handle_from_conditions(conditions):
+ """
+ Given a list of conditions, return the instance handle of the
+ device rule if there is one
+ """
+ for c in conditions:
+ if c['kind'] == 'device':
+ return c['instance_handle']
+ return None
+
+
+def _filter_ruleset_with_path(ruleset, path):
+ if path == []:
+ raise UnrecognizedRequestError(
+ PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
+ )
+
+ if path[0] == '':
+ return ruleset
+ template_kind = path[0]
+ if template_kind not in ruleset:
+ raise UnrecognizedRequestError()
+ path = path[1:]
+ if path == []:
+ raise UnrecognizedRequestError(
+ PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
+ )
+ if path[0] == '':
+ return ruleset[template_kind]
+ rule_id = path[0]
+ for r in ruleset[template_kind]:
+ if r['rule_id'] == rule_id:
+ return r
+ raise NotFoundError
+
+
+def _priority_class_from_spec(spec):
+ if spec['template'] not in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys():
+ raise InvalidRuleException("Unknown template: %s" % (spec['kind']))
+ pc = PushRuleRestServlet.PRIORITY_CLASS_MAP[spec['template']]
+
+ if spec['scope'] == 'device':
+ pc += len(PushRuleRestServlet.PRIORITY_CLASS_MAP)
+
+ return pc
+
+
+def _priority_class_to_template_name(pc):
+ if pc > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']:
+ # per-device
+ prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP)
+ return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
+ else:
+ return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[pc]
+
+
+def _rule_to_template(rule):
+ template_name = _priority_class_to_template_name(rule['priority_class'])
+ if template_name in ['default']:
+ return {k: rule[k] for k in ["conditions", "actions"]}
+ elif template_name in ['override', 'underride']:
+ return {k: rule[k] for k in ["rule_id", "conditions", "actions"]}
+ elif template_name in ["sender", "room"]:
+ return {k: rule[k] for k in ["rule_id", "actions"]}
+ elif template_name == 'content':
+ if len(rule["conditions"]) != 1:
+ return None
+ thecond = rule["conditions"][0]
+ if "pattern" not in thecond:
+ return None
+ ret = {k: rule[k] for k in ["rule_id", "actions"]}
+ ret["pattern"] = thecond["pattern"]
+ return ret
+
+
+def _strip_device_condition(rule):
+ for i, c in enumerate(rule['conditions']):
+ if c['kind'] == 'device':
+ del rule['conditions'][i]
+ return rule
+
+
+class InvalidRuleException(Exception):
+ pass
+
+
+# XXX: C+ped from rest/room.py - surely this should be common?
+def _parse_json(request):
+ try:
+ content = json.loads(request.content.read())
+ if type(content) != dict:
+ raise SynapseError(400, "Content must be a JSON object.",
+ errcode=Codes.NOT_JSON)
+ return content
+ except ValueError:
+ raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+
+
+def register_servlets(hs, http_server):
+ PushRuleRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
new file mode 100644
index 0000000000..353a4a6589
--- /dev/null
+++ b/synapse/rest/client/v1/pusher.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError, Codes
+from synapse.push import PusherConfigException
+from .base import ClientV1RestServlet, client_path_pattern
+
+import json
+
+
+class PusherRestServlet(ClientV1RestServlet):
+ PATTERN = client_path_pattern("/pushers/set$")
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ user, _ = yield self.auth.get_user_by_req(request)
+
+ content = _parse_json(request)
+
+ pusher_pool = self.hs.get_pusherpool()
+
+ if ('pushkey' in content and 'app_id' in content
+ and 'kind' in content and
+ content['kind'] is None):
+ yield pusher_pool.remove_pusher(
+ content['app_id'], content['pushkey']
+ )
+ defer.returnValue((200, {}))
+
+ reqd = ['instance_handle', 'kind', 'app_id', 'app_display_name',
+ 'device_display_name', 'pushkey', 'lang', 'data']
+ missing = []
+ for i in reqd:
+ if i not in content:
+ missing.append(i)
+ if len(missing):
+ raise SynapseError(400, "Missing parameters: "+','.join(missing),
+ errcode=Codes.MISSING_PARAM)
+
+ try:
+ yield pusher_pool.add_pusher(
+ user_name=user.to_string(),
+ instance_handle=content['instance_handle'],
+ kind=content['kind'],
+ app_id=content['app_id'],
+ app_display_name=content['app_display_name'],
+ device_display_name=content['device_display_name'],
+ pushkey=content['pushkey'],
+ lang=content['lang'],
+ data=content['data']
+ )
+ except PusherConfigException as pce:
+ raise SynapseError(400, "Config Error: "+pce.message,
+ errcode=Codes.MISSING_PARAM)
+
+ defer.returnValue((200, {}))
+
+ def on_OPTIONS(self, _):
+ return 200, {}
+
+
+# XXX: C+ped from rest/room.py - surely this should be common?
+def _parse_json(request):
+ try:
+ content = json.loads(request.content.read())
+ if type(content) != dict:
+ raise SynapseError(400, "Content must be a JSON object.",
+ errcode=Codes.NOT_JSON)
+ return content
+ except ValueError:
+ raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+
+
+def register_servlets(hs, http_server):
+ PusherRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 58b09b6fc1..410f19ccf6 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
room_config = self.get_room_config(request)
info = yield self.make_room(room_config, auth_user, None)
@@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key):
- user = yield self.auth.get_user_by_req(request)
+ user, client = yield self.auth.get_user_by_req(request)
msg_handler = self.handlers.message_handler
data = yield msg_handler.get_room_data(
@@ -142,8 +142,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
defer.returnValue((200, data.get_dict()["content"]))
@defer.inlineCallbacks
- def on_PUT(self, request, room_id, event_type, state_key):
- user = yield self.auth.get_user_by_req(request)
+ def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+ user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
@@ -158,7 +158,9 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler
- yield msg_handler.create_and_send_event(event_dict)
+ yield msg_handler.create_and_send_event(
+ event_dict, client=client, txn_id=txn_id,
+ )
defer.returnValue((200, {}))
@@ -172,8 +174,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server, with_get=True)
@defer.inlineCallbacks
- def on_POST(self, request, room_id, event_type):
- user = yield self.auth.get_user_by_req(request)
+ def on_POST(self, request, room_id, event_type, txn_id=None):
+ user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@@ -183,7 +185,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
"content": content,
"room_id": room_id,
"sender": user.to_string(),
- }
+ },
+ client=client,
+ txn_id=txn_id,
)
defer.returnValue((200, {"event_id": event.event_id}))
@@ -200,7 +204,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
except KeyError:
pass
- response = yield self.on_POST(request, room_id, event_type)
+ response = yield self.on_POST(request, room_id, event_type, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@@ -215,8 +219,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
- def on_POST(self, request, room_identifier):
- user = yield self.auth.get_user_by_req(request)
+ def on_POST(self, request, room_identifier, txn_id=None):
+ user, client = yield self.auth.get_user_by_req(request)
# the identifier could be a room alias or a room id. Try one then the
# other if it fails to parse, without swallowing other valid
@@ -245,7 +249,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
"room_id": identifier.to_string(),
"sender": user.to_string(),
"state_key": user.to_string(),
- }
+ },
+ client=client,
+ txn_id=txn_id,
)
defer.returnValue((200, {"room_id": identifier.to_string()}))
@@ -259,7 +265,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
except KeyError:
pass
- response = yield self.on_POST(request, room_identifier)
+ response = yield self.on_POST(request, room_identifier, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@@ -283,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
- user = yield self.auth.get_user_by_req(request)
+ user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.room_member_handler
members = yield handler.get_room_members_as_pagination_chunk(
room_id=room_id,
@@ -311,7 +317,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
- user = yield self.auth.get_user_by_req(request)
+ user, client = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(
request, default_limit=10,
)
@@ -335,7 +341,7 @@ class RoomStateRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
- user = yield self.auth.get_user_by_req(request)
+ user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.message_handler
# Get all the current state for this room
events = yield handler.get_state_events(
@@ -351,7 +357,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
- user = yield self.auth.get_user_by_req(request)
+ user, client = yield self.auth.get_user_by_req(request)
pagination_config = PaginationConfig.from_request(request)
content = yield self.handlers.message_handler.room_initial_sync(
room_id=room_id,
@@ -395,8 +401,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
- def on_POST(self, request, room_id, membership_action):
- user = yield self.auth.get_user_by_req(request)
+ def on_POST(self, request, room_id, membership_action, txn_id=None):
+ user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
@@ -418,7 +424,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
"room_id": room_id,
"sender": user.to_string(),
"state_key": state_key,
- }
+ },
+ client=client,
+ txn_id=txn_id,
)
defer.returnValue((200, {}))
@@ -432,7 +440,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
except KeyError:
pass
- response = yield self.on_POST(request, room_id, membership_action)
+ response = yield self.on_POST(
+ request, room_id, membership_action, txn_id
+ )
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@@ -444,8 +454,8 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
register_txn_path(self, PATTERN, http_server)
@defer.inlineCallbacks
- def on_POST(self, request, room_id, event_id):
- user = yield self.auth.get_user_by_req(request)
+ def on_POST(self, request, room_id, event_id, txn_id=None):
+ user, client = yield self.auth.get_user_by_req(request)
content = _parse_json(request)
msg_handler = self.handlers.message_handler
@@ -456,7 +466,9 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
"room_id": room_id,
"sender": user.to_string(),
"redacts": event_id,
- }
+ },
+ client=client,
+ txn_id=txn_id,
)
defer.returnValue((200, {"event_id": event.event_id}))
@@ -470,7 +482,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
except KeyError:
pass
- response = yield self.on_POST(request, room_id, event_id)
+ response = yield self.on_POST(request, room_id, event_id, txn_id)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
@@ -483,7 +495,7 @@ class RoomTypingRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
room_id = urllib.unquote(room_id)
target_user = UserID.from_string(urllib.unquote(user_id))
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 822d863ce6..11d08fbced 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret
diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py
index bb740e2803..8f611de3a8 100644
--- a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -13,6 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from . import (
+ sync,
+ filter
+)
from synapse.http.server import JsonResource
@@ -26,4 +30,5 @@ class ClientV2AlphaRestResource(JsonResource):
@staticmethod
def register_servlets(client_resource, hs):
- pass
+ sync.register_servlets(hs, client_resource)
+ filter.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
new file mode 100644
index 0000000000..6ddc495d23
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -0,0 +1,104 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import AuthError, SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.types import UserID
+
+from ._base import client_v2_pattern
+
+import json
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class GetFilterRestServlet(RestServlet):
+ PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
+
+ def __init__(self, hs):
+ super(GetFilterRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.filtering = hs.get_filtering()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id, filter_id):
+ target_user = UserID.from_string(user_id)
+ auth_user, client = yield self.auth.get_user_by_req(request)
+
+ if target_user != auth_user:
+ raise AuthError(403, "Cannot get filters for other users")
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only get filters for local users")
+
+ try:
+ filter_id = int(filter_id)
+ except:
+ raise SynapseError(400, "Invalid filter_id")
+
+ try:
+ filter = yield self.filtering.get_user_filter(
+ user_localpart=target_user.localpart,
+ filter_id=filter_id,
+ )
+
+ defer.returnValue((200, filter.filter_json))
+ except KeyError:
+ raise SynapseError(400, "No such filter")
+
+
+class CreateFilterRestServlet(RestServlet):
+ PATTERN = client_v2_pattern("/user/(?P<user_id>[^/]*)/filter")
+
+ def __init__(self, hs):
+ super(CreateFilterRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.filtering = hs.get_filtering()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, user_id):
+ target_user = UserID.from_string(user_id)
+ auth_user, client = yield self.auth.get_user_by_req(request)
+
+ if target_user != auth_user:
+ raise AuthError(403, "Cannot create filters for other users")
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only create filters for local users")
+
+ try:
+ content = json.loads(request.content.read())
+
+ # TODO(paul): check for required keys and invalid keys
+ except:
+ raise SynapseError(400, "Invalid filter definition")
+
+ filter_id = yield self.filtering.add_user_filter(
+ user_localpart=target_user.localpart,
+ user_filter=content,
+ )
+
+ defer.returnValue((200, {"filter_id": str(filter_id)}))
+
+
+def register_servlets(hs, http_server):
+ GetFilterRestServlet(hs).register(http_server)
+ CreateFilterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
new file mode 100644
index 0000000000..81d5cf8ead
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -0,0 +1,207 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet
+from synapse.handlers.sync import SyncConfig
+from synapse.types import StreamToken
+from synapse.events.utils import (
+ serialize_event, format_event_for_client_v2_without_event_id,
+)
+from synapse.api.filtering import Filter
+from ._base import client_v2_pattern
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class SyncRestServlet(RestServlet):
+ """
+
+ GET parameters::
+ timeout(int): How long to wait for new events in milliseconds.
+ limit(int): Maxiumum number of events per room to return.
+ gap(bool): Create gaps the message history if limit is exceeded to
+ ensure that the client has the most recent messages. Defaults to
+ "true".
+ sort(str,str): tuple of sort key (e.g. "timeline") and direction
+ (e.g. "asc", "desc"). Defaults to "timeline,asc".
+ since(batch_token): Batch token when asking for incremental deltas.
+ set_presence(str): What state the device presence should be set to.
+ default is "online".
+ backfill(bool): Should the HS request message history from other
+ servers. This may take a long time making it unsuitable for clients
+ expecting a prompt response. Defaults to "true".
+ filter(filter_id): A filter to apply to the events returned.
+ filter_*: Filter override parameters.
+
+ Response JSON::
+ {
+ "next_batch": // batch token for the next /sync
+ "private_user_data": // private events for this user.
+ "public_user_data": // public events for all users including the
+ // public events for this user.
+ "rooms": [{ // List of rooms with updates.
+ "room_id": // Id of the room being updated
+ "limited": // Was the per-room event limit exceeded?
+ "published": // Is the room published by our HS?
+ "event_map": // Map of EventID -> event JSON.
+ "events": { // The recent events in the room if gap is "true"
+ // otherwise the next events in the room.
+ "batch": [] // list of EventIDs in the "event_map".
+ "prev_batch": // back token for getting previous events.
+ }
+ "state": [] // list of EventIDs updating the current state to
+ // be what it should be at the end of the batch.
+ "ephemeral": []
+ }]
+ }
+ """
+
+ PATTERN = client_v2_pattern("/sync$")
+ ALLOWED_SORT = set(["timeline,asc", "timeline,desc"])
+ ALLOWED_PRESENCE = set(["online", "offline", "idle"])
+
+ def __init__(self, hs):
+ super(SyncRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.sync_handler = hs.get_handlers().sync_handler
+ self.clock = hs.get_clock()
+ self.filtering = hs.get_filtering()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ user, client = yield self.auth.get_user_by_req(request)
+
+ timeout = self.parse_integer(request, "timeout", default=0)
+ limit = self.parse_integer(request, "limit", required=True)
+ gap = self.parse_boolean(request, "gap", default=True)
+ sort = self.parse_string(
+ request, "sort", default="timeline,asc",
+ allowed_values=self.ALLOWED_SORT
+ )
+ since = self.parse_string(request, "since")
+ set_presence = self.parse_string(
+ request, "set_presence", default="online",
+ allowed_values=self.ALLOWED_PRESENCE
+ )
+ backfill = self.parse_boolean(request, "backfill", default=False)
+ filter_id = self.parse_string(request, "filter", default=None)
+
+ logger.info(
+ "/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r,"
+ " set_presence=%r, backfill=%r, filter_id=%r" % (
+ user, timeout, limit, gap, sort, since, set_presence,
+ backfill, filter_id
+ )
+ )
+
+ # TODO(mjark): Load filter and apply overrides.
+ try:
+ filter = yield self.filtering.get_user_filter(
+ user.localpart, filter_id
+ )
+ except:
+ filter = Filter({})
+ # filter = filter.apply_overrides(http_request)
+ #if filter.matches(event):
+ # # stuff
+
+ sync_config = SyncConfig(
+ user=user,
+ client_info=client,
+ gap=gap,
+ limit=limit,
+ sort=sort,
+ backfill=backfill,
+ filter=filter,
+ )
+
+ if since is not None:
+ since_token = StreamToken.from_string(since)
+ else:
+ since_token = None
+
+ sync_result = yield self.sync_handler.wait_for_sync_for_user(
+ sync_config, since_token=since_token, timeout=timeout
+ )
+
+ time_now = self.clock.time_msec()
+
+ response_content = {
+ "public_user_data": self.encode_user_data(
+ sync_result.public_user_data, filter, time_now
+ ),
+ "private_user_data": self.encode_user_data(
+ sync_result.private_user_data, filter, time_now
+ ),
+ "rooms": self.encode_rooms(
+ sync_result.rooms, filter, time_now, client.token_id
+ ),
+ "next_batch": sync_result.next_batch.to_string(),
+ }
+
+ defer.returnValue((200, response_content))
+
+ def encode_user_data(self, events, filter, time_now):
+ return events
+
+ def encode_rooms(self, rooms, filter, time_now, token_id):
+ return [
+ self.encode_room(room, filter, time_now, token_id)
+ for room in rooms
+ ]
+
+ @staticmethod
+ def encode_room(room, filter, time_now, token_id):
+ event_map = {}
+ state_events = filter.filter_room_state(room.state)
+ recent_events = filter.filter_room_events(room.events)
+ state_event_ids = []
+ recent_event_ids = []
+ for event in state_events:
+ # TODO(mjark): Respect formatting requirements in the filter.
+ event_map[event.event_id] = serialize_event(
+ event, time_now, token_id=token_id,
+ event_format=format_event_for_client_v2_without_event_id,
+ )
+ state_event_ids.append(event.event_id)
+
+ for event in recent_events:
+ # TODO(mjark): Respect formatting requirements in the filter.
+ event_map[event.event_id] = serialize_event(
+ event, time_now, token_id=token_id,
+ event_format=format_event_for_client_v2_without_event_id,
+ )
+ recent_event_ids.append(event.event_id)
+ result = {
+ "room_id": room.room_id,
+ "event_map": event_map,
+ "events": {
+ "batch": recent_event_ids,
+ "prev_batch": room.prev_batch.to_string(),
+ },
+ "state": state_event_ids,
+ "limited": room.limited,
+ "published": room.published,
+ "ephemeral": room.ephemeral,
+ }
+ return result
+
+
+def register_servlets(hs, http_server):
+ SyncRestServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py
index 79ae0e3d74..22e26e3cd5 100644
--- a/synapse/rest/media/v0/content_repository.py
+++ b/synapse/rest/media/v0/content_repository.py
@@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource):
@defer.inlineCallbacks
def map_request_to_name(self, request):
# auth the user
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
# namespace all file uploads on the user
prefix = base64.urlsafe_b64encode(
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index b1718a630b..b939a30e19 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -42,7 +42,7 @@ class UploadResource(BaseMediaResource):
@defer.inlineCallbacks
def _async_render_POST(self, request):
try:
- auth_user = yield self.auth.get_user_by_req(request)
+ auth_user, client = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
content_length = request.getHeader("Content-Length")
diff --git a/synapse/server.py b/synapse/server.py
index 891c5aa13d..ba2b2593f1 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -31,7 +31,9 @@ from synapse.util.lockutils import LockManager
from synapse.streams.events import EventSources
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.keyring import Keyring
+from synapse.push.pusherpool import PusherPool
from synapse.events.builder import EventBuilderFactory
+from synapse.api.filtering import Filtering
class BaseHomeServer(object):
@@ -79,7 +81,9 @@ class BaseHomeServer(object):
'event_sources',
'ratelimiter',
'keyring',
+ 'pusherpool',
'event_builder_factory',
+ 'filtering',
]
def __init__(self, hostname, **kwargs):
@@ -198,3 +202,9 @@ class HomeServer(BaseHomeServer):
clock=self.get_clock(),
hostname=self.hostname,
)
+
+ def build_filtering(self):
+ return Filtering(self)
+
+ def build_pusherpool(self):
+ return PusherPool(self)
diff --git a/synapse/state.py b/synapse/state.py
index 8144fa02b4..8a056ee955 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError
from synapse.events.snapshot import EventContext
from collections import namedtuple
@@ -36,12 +37,16 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
+AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,)
+
+
class StateHandler(object):
""" Responsible for doing state conflict resolution.
"""
def __init__(self, hs):
self.store = hs.get_datastore()
+ self.hs = hs
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
@@ -163,10 +168,17 @@ class StateHandler(object):
first is the name of a state group if one and only one is involved,
otherwise `None`.
"""
+ logger.debug("resolve_state_groups event_ids %s", event_ids)
+
state_groups = yield self.store.get_state_groups(
event_ids
)
+ logger.debug(
+ "resolve_state_groups state_groups %s",
+ state_groups.keys()
+ )
+
group_names = set(state_groups.keys())
if len(group_names) == 1:
name, state_list = state_groups.items().pop()
@@ -210,64 +222,93 @@ class StateHandler(object):
else:
prev_states = []
+ auth_events = {
+ k: e for k, e in unconflicted_state.items()
+ if k[0] in AuthEventTypes
+ }
+
try:
- new_state = {}
- new_state.update(unconflicted_state)
- for key, events in conflicted_state.items():
- new_state[key] = self._resolve_state_events(events)
+ resolved_state = self._resolve_state_events(
+ conflicted_state, auth_events
+ )
except:
logger.exception("Failed to resolve state")
raise
- defer.returnValue((None, new_state, prev_states))
-
- def _get_power_level_from_event_state(self, event, user_id):
- if hasattr(event, "old_state_events") and event.old_state_events:
- key = (EventTypes.PowerLevels, "", )
- power_level_event = event.old_state_events.get(key)
- level = None
- if power_level_event:
- level = power_level_event.content.get("users", {}).get(
- user_id
- )
- if not level:
- level = power_level_event.content.get("users_default", 0)
+ new_state = unconflicted_state
+ new_state.update(resolved_state)
- return level
- else:
- return 0
+ defer.returnValue((None, new_state, prev_states))
@log_function
- def _resolve_state_events(self, events):
- curr_events = events
-
- new_powers = [
- self._get_power_level_from_event_state(e, e.user_id)
- for e in curr_events
- ]
-
- new_powers = [
- int(p) if p else 0 for p in new_powers
- ]
+ def _resolve_state_events(self, conflicted_state, auth_events):
+ """ This is where we actually decide which of the conflicted state to
+ use.
+
+ We resolve conflicts in the following order:
+ 1. power levels
+ 2. memberships
+ 3. other events.
+ """
+ resolved_state = {}
+ power_key = (EventTypes.PowerLevels, "")
+ if power_key in conflicted_state.items():
+ power_levels = conflicted_state[power_key]
+ resolved_state[power_key] = self._resolve_auth_events(power_levels)
+
+ auth_events.update(resolved_state)
+
+ for key, events in conflicted_state.items():
+ if key[0] == EventTypes.Member:
+ resolved_state[key] = self._resolve_auth_events(
+ events,
+ auth_events
+ )
- max_power = max(new_powers)
+ auth_events.update(resolved_state)
- curr_events = [
- z[0] for z in zip(curr_events, new_powers)
- if z[1] == max_power
- ]
+ for key, events in conflicted_state.items():
+ if key not in resolved_state:
+ resolved_state[key] = self._resolve_normal_events(
+ events, auth_events
+ )
- if not curr_events:
- raise RuntimeError("Max didn't get a max?")
- elif len(curr_events) == 1:
- return curr_events[0]
-
- # TODO: For now, just choose the one with the largest event_id.
- return (
- sorted(
- curr_events,
- key=lambda e: hashlib.sha1(
- e.event_id + e.user_id + e.room_id + e.type
- ).hexdigest()
- )[0]
- )
+ return resolved_state
+
+ def _resolve_auth_events(self, events, auth_events):
+ reverse = [i for i in reversed(self._ordered_events(events))]
+
+ auth_events = dict(auth_events)
+
+ prev_event = reverse[0]
+ for event in reverse[1:]:
+ auth_events[(prev_event.type, prev_event.state_key)] = prev_event
+ try:
+ # FIXME: hs.get_auth() is bad style, but we need to do it to
+ # get around circular deps.
+ self.hs.get_auth().check(event, auth_events)
+ prev_event = event
+ except AuthError:
+ return prev_event
+
+ return event
+
+ def _resolve_normal_events(self, events, auth_events):
+ for event in self._ordered_events(events):
+ try:
+ # FIXME: hs.get_auth() is bad style, but we need to do it to
+ # get around circular deps.
+ self.hs.get_auth().check(event, auth_events)
+ return event
+ except AuthError:
+ pass
+
+ # Use the last event (the one with the least depth) if they all fail
+ # the auth check.
+ return event
+
+ def _ordered_events(self, events):
+ def key_func(e):
+ return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
+
+ return sorted(events, key=key_func)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index e86b981b47..9bbd553dfc 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -30,10 +30,14 @@ from .stream import StreamStore
from .transactions import TransactionStore
from .keys import KeyStore
from .event_federation import EventFederationStore
+from .pusher import PusherStore
+from .push_rule import PushRuleStore
from .media_repository import MediaRepositoryStore
+from .rejections import RejectionsStore
from .state import StateStore
from .signatures import SignatureStore
+from .filtering import FilteringStore
from syutil.base64util import decode_base64
from syutil.jsonutil import encode_canonical_json
@@ -61,14 +65,20 @@ SCHEMAS = [
"state",
"event_edges",
"event_signatures",
+ "pusher",
"media_repository",
+<<<<<<< HEAD
"application_services"
+=======
+ "filtering",
+ "rejections",
+>>>>>>> develop
]
# Remember to update this number every time an incompatible change is made to
# database schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 11
+SCHEMA_VERSION = 12
class _RollbackButIsFineException(Exception):
@@ -82,8 +92,17 @@ class DataStore(RoomMemberStore, RoomStore,
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
PresenceStore, TransactionStore,
DirectoryStore, KeyStore, StateStore, SignatureStore,
+<<<<<<< HEAD
EventFederationStore, MediaRepositoryStore,
ApplicationServiceStore
+=======
+ EventFederationStore,
+ MediaRepositoryStore,
+ RejectionsStore,
+ FilteringStore,
+ PusherStore,
+ PushRuleStore
+>>>>>>> develop
):
def __init__(self, hs):
@@ -226,6 +245,9 @@ class DataStore(RoomMemberStore, RoomStore,
if not outlier:
self._store_state_groups_txn(txn, event, context)
+ if context.rejected:
+ self._store_rejections_txn(txn, event.event_id, context.rejected)
+
if current_state:
txn.execute(
"DELETE FROM current_state_events WHERE room_id = ?",
@@ -264,7 +286,7 @@ class DataStore(RoomMemberStore, RoomStore,
or_replace=True,
)
- if is_new_state:
+ if is_new_state and not context.rejected:
self._simple_insert_txn(
txn,
"current_state_events",
@@ -290,7 +312,7 @@ class DataStore(RoomMemberStore, RoomStore,
or_ignore=True,
)
- if not backfilled:
+ if not backfilled and not context.rejected:
self._simple_insert_txn(
txn,
table="state_forward_extremities",
@@ -372,9 +394,12 @@ class DataStore(RoomMemberStore, RoomStore,
"redacted": del_sql,
}
- if event_type:
+ if event_type and state_key is not None:
sql += " AND s.type = ? AND s.state_key = ? "
args = (room_id, event_type, state_key)
+ elif event_type:
+ sql += " AND s.type = ?"
+ args = (room_id, event_type)
else:
args = (room_id, )
@@ -384,6 +409,41 @@ class DataStore(RoomMemberStore, RoomStore,
defer.returnValue(events)
@defer.inlineCallbacks
+ def get_room_name_and_aliases(self, room_id):
+ del_sql = (
+ "SELECT event_id FROM redactions WHERE redacts = e.event_id "
+ "LIMIT 1"
+ )
+
+ sql = (
+ "SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
+ "INNER JOIN current_state_events as c ON e.event_id = c.event_id "
+ "INNER JOIN state_events as s ON e.event_id = s.event_id "
+ "WHERE c.room_id = ? "
+ ) % {
+ "redacted": del_sql,
+ }
+
+ sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
+ sql += " OR s.type = 'm.room.aliases')"
+ args = (room_id,)
+
+ results = yield self._execute_and_decode(sql, *args)
+
+ events = yield self._parse_events(results)
+
+ name = None
+ aliases = []
+
+ for e in events:
+ if e.type == 'm.room.name':
+ name = e.content['name']
+ elif e.type == 'm.room.aliases':
+ aliases.extend(e.content['aliases'])
+
+ defer.returnValue((name, aliases))
+
+ @defer.inlineCallbacks
def _get_min_token(self):
row = yield self._execute(
None,
@@ -419,6 +479,35 @@ class DataStore(RoomMemberStore, RoomStore,
],
)
+ def have_events(self, event_ids):
+ """Given a list of event ids, check if we have already processed them.
+
+ Returns:
+ dict: Has an entry for each event id we already have seen. Maps to
+ the rejected reason string if we rejected the event, else maps to
+ None.
+ """
+ def f(txn):
+ sql = (
+ "SELECT e.event_id, reason FROM events as e "
+ "LEFT JOIN rejections as r ON e.event_id = r.event_id "
+ "WHERE e.event_id = ?"
+ )
+
+ res = {}
+ for event_id in event_ids:
+ txn.execute(sql, (event_id,))
+ row = txn.fetchone()
+ if row:
+ _, rejected = row
+ res[event_id] = rejected
+
+ return res
+
+ return self.runInteraction(
+ "have_events", f,
+ )
+
def schema_path(schema):
""" Get a filesystem path for the named database schema
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index f660fc6eaf..b350fd61f1 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -193,6 +193,50 @@ class SQLBaseStore(object):
txn.execute(sql, values.values())
return txn.lastrowid
+ def _simple_upsert(self, table, keyvalues, values):
+ """
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key tables and their new values
+ values (dict): The nonunique columns and their new values
+ Returns: A deferred
+ """
+ return self.runInteraction(
+ "_simple_upsert",
+ self._simple_upsert_txn, table, keyvalues, values
+ )
+
+ def _simple_upsert_txn(self, txn, table, keyvalues, values):
+ # Try to update
+ sql = "UPDATE %s SET %s WHERE %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in values),
+ " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ )
+ sqlargs = values.values() + keyvalues.values()
+ logger.debug(
+ "[SQL] %s Args=%s",
+ sql, sqlargs,
+ )
+
+ txn.execute(sql, sqlargs)
+ if txn.rowcount == 0:
+ # We didn't update and rows so insert a new one
+ allvalues = {}
+ allvalues.update(keyvalues)
+ allvalues.update(values)
+
+ sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues)
+ )
+ logger.debug(
+ "[SQL] %s Args=%s",
+ sql, keyvalues.values(),
+ )
+ txn.execute(sql, allvalues.values())
+
def _simple_select_one(self, table, keyvalues, retcols,
allow_none=False):
"""Executes a SELECT query on the named table, which is expected to
@@ -344,8 +388,8 @@ class SQLBaseStore(object):
if updatevalues:
update_sql = "UPDATE %s SET %s WHERE %s" % (
table,
- ", ".join("%s = ?" % (k) for k in updatevalues),
- " AND ".join("%s = ?" % (k) for k in keyvalues)
+ ", ".join("%s = ?" % (k,) for k in updatevalues),
+ " AND ".join("%s = ?" % (k,) for k in keyvalues)
)
def func(txn):
@@ -458,10 +502,12 @@ class SQLBaseStore(object):
return [e for e in events if e]
def _get_event_txn(self, txn, event_id, check_redacted=True,
- get_prev_content=False):
+ get_prev_content=False, allow_rejected=False):
sql = (
- "SELECT internal_metadata, json, r.event_id FROM event_json as e "
+ "SELECT e.internal_metadata, e.json, r.event_id, rej.reason "
+ "FROM event_json as e "
"LEFT JOIN redactions as r ON e.event_id = r.redacts "
+ "LEFT JOIN rejections as rej on rej.event_id = e.event_id "
"WHERE e.event_id = ? "
"LIMIT 1 "
)
@@ -473,13 +519,16 @@ class SQLBaseStore(object):
if not res:
return None
- internal_metadata, js, redacted = res
+ internal_metadata, js, redacted, rejected_reason = res
- return self._get_event_from_row_txn(
- txn, internal_metadata, js, redacted,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- )
+ if allow_rejected or not rejected_reason:
+ return self._get_event_from_row_txn(
+ txn, internal_metadata, js, redacted,
+ check_redacted=check_redacted,
+ get_prev_content=get_prev_content,
+ )
+ else:
+ return None
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False):
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
new file mode 100644
index 0000000000..e86eeced45
--- /dev/null
+++ b/synapse/storage/filtering.py
@@ -0,0 +1,63 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from ._base import SQLBaseStore
+
+import json
+
+
+class FilteringStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def get_user_filter(self, user_localpart, filter_id):
+ def_json = yield self._simple_select_one_onecol(
+ table="user_filters",
+ keyvalues={
+ "user_id": user_localpart,
+ "filter_id": filter_id,
+ },
+ retcol="filter_json",
+ allow_none=False,
+ )
+
+ defer.returnValue(json.loads(def_json))
+
+ def add_user_filter(self, user_localpart, user_filter):
+ def_json = json.dumps(user_filter)
+
+ # Need an atomic transaction to SELECT the maximal ID so far then
+ # INSERT a new one
+ def _do_txn(txn):
+ sql = (
+ "SELECT MAX(filter_id) FROM user_filters "
+ "WHERE user_id = ?"
+ )
+ txn.execute(sql, (user_localpart,))
+ max_id = txn.fetchone()[0]
+ if max_id is None:
+ filter_id = 0
+ else:
+ filter_id = max_id + 1
+
+ sql = (
+ "INSERT INTO user_filters (user_id, filter_id, filter_json)"
+ "VALUES(?, ?, ?)"
+ )
+ txn.execute(sql, (user_localpart, filter_id, def_json))
+
+ return filter_id
+
+ return self.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
new file mode 100644
index 0000000000..27502d2399
--- /dev/null
+++ b/synapse/storage/push_rule.py
@@ -0,0 +1,213 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+
+from ._base import SQLBaseStore, Table
+from twisted.internet import defer
+
+import logging
+import copy
+import json
+
+logger = logging.getLogger(__name__)
+
+
+class PushRuleStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def get_push_rules_for_user_name(self, user_name):
+ sql = (
+ "SELECT "+",".join(PushRuleTable.fields)+" "
+ "FROM "+PushRuleTable.table_name+" "
+ "WHERE user_name = ? "
+ "ORDER BY priority_class DESC, priority DESC"
+ )
+ rows = yield self._execute(None, sql, user_name)
+
+ dicts = []
+ for r in rows:
+ d = {}
+ for i, f in enumerate(PushRuleTable.fields):
+ d[f] = r[i]
+ dicts.append(d)
+
+ defer.returnValue(dicts)
+
+ @defer.inlineCallbacks
+ def add_push_rule(self, before, after, **kwargs):
+ vals = copy.copy(kwargs)
+ if 'conditions' in vals:
+ vals['conditions'] = json.dumps(vals['conditions'])
+ if 'actions' in vals:
+ vals['actions'] = json.dumps(vals['actions'])
+ # we could check the rest of the keys are valid column names
+ # but sqlite will do that anyway so I think it's just pointless.
+ if 'id' in vals:
+ del vals['id']
+
+ if before or after:
+ ret = yield self.runInteraction(
+ "_add_push_rule_relative_txn",
+ self._add_push_rule_relative_txn,
+ before=before,
+ after=after,
+ **vals
+ )
+ defer.returnValue(ret)
+ else:
+ ret = yield self.runInteraction(
+ "_add_push_rule_highest_priority_txn",
+ self._add_push_rule_highest_priority_txn,
+ **vals
+ )
+ defer.returnValue(ret)
+
+ def _add_push_rule_relative_txn(self, txn, user_name, **kwargs):
+ after = None
+ relative_to_rule = None
+ if 'after' in kwargs and kwargs['after']:
+ after = kwargs['after']
+ relative_to_rule = after
+ if 'before' in kwargs and kwargs['before']:
+ relative_to_rule = kwargs['before']
+
+ # get the priority of the rule we're inserting after/before
+ sql = (
+ "SELECT priority_class, priority FROM ? "
+ "WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,)
+ )
+ txn.execute(sql, (user_name, relative_to_rule))
+ res = txn.fetchall()
+ if not res:
+ raise RuleNotFoundException("before/after rule not found: %s" % (relative_to_rule))
+ priority_class, base_rule_priority = res[0]
+
+ if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class:
+ raise InconsistentRuleException(
+ "Given priority class does not match class of relative rule"
+ )
+
+ new_rule = copy.copy(kwargs)
+ if 'before' in new_rule:
+ del new_rule['before']
+ if 'after' in new_rule:
+ del new_rule['after']
+ new_rule['priority_class'] = priority_class
+ new_rule['user_name'] = user_name
+
+ # check if the priority before/after is free
+ new_rule_priority = base_rule_priority
+ if after:
+ new_rule_priority -= 1
+ else:
+ new_rule_priority += 1
+
+ new_rule['priority'] = new_rule_priority
+
+ sql = (
+ "SELECT COUNT(*) FROM " + PushRuleTable.table_name +
+ " WHERE user_name = ? AND priority_class = ? AND priority = ?"
+ )
+ txn.execute(sql, (user_name, priority_class, new_rule_priority))
+ res = txn.fetchall()
+ num_conflicting = res[0][0]
+
+ # if there are conflicting rules, bump everything
+ if num_conflicting:
+ sql = "UPDATE "+PushRuleTable.table_name+" SET priority = priority "
+ if after:
+ sql += "-1"
+ else:
+ sql += "+1"
+ sql += " WHERE user_name = ? AND priority_class = ? AND priority "
+ if after:
+ sql += "<= ?"
+ else:
+ sql += ">= ?"
+
+ txn.execute(sql, (user_name, priority_class, new_rule_priority))
+
+ # now insert the new rule
+ sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
+ sql += ",".join(new_rule.keys())+") VALUES ("
+ sql += ", ".join(["?" for _ in new_rule.keys()])+")"
+
+ txn.execute(sql, new_rule.values())
+
+ def _add_push_rule_highest_priority_txn(self, txn, user_name,
+ priority_class, **kwargs):
+ # find the highest priority rule in that class
+ sql = (
+ "SELECT COUNT(*), MAX(priority) FROM " + PushRuleTable.table_name +
+ " WHERE user_name = ? and priority_class = ?"
+ )
+ txn.execute(sql, (user_name, priority_class))
+ res = txn.fetchall()
+ (how_many, highest_prio) = res[0]
+
+ new_prio = 0
+ if how_many > 0:
+ new_prio = highest_prio + 1
+
+ # and insert the new rule
+ new_rule = copy.copy(kwargs)
+ if 'id' in new_rule:
+ del new_rule['id']
+ new_rule['user_name'] = user_name
+ new_rule['priority_class'] = priority_class
+ new_rule['priority'] = new_prio
+
+ sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" ("
+ sql += ",".join(new_rule.keys())+") VALUES ("
+ sql += ", ".join(["?" for _ in new_rule.keys()])+")"
+
+ txn.execute(sql, new_rule.values())
+
+ @defer.inlineCallbacks
+ def delete_push_rule(self, user_name, rule_id, **kwargs):
+ """
+ Delete a push rule. Args specify the row to be deleted and can be
+ any of the columns in the push_rule table, but below are the
+ standard ones
+
+ Args:
+ user_name (str): The matrix ID of the push rule owner
+ rule_id (str): The rule_id of the rule to be deleted
+ """
+ yield self._simple_delete_one(PushRuleTable.table_name, kwargs)
+
+
+class RuleNotFoundException(Exception):
+ pass
+
+
+class InconsistentRuleException(Exception):
+ pass
+
+
+class PushRuleTable(Table):
+ table_name = "push_rules"
+
+ fields = [
+ "id",
+ "user_name",
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
+ ]
+
+ EntryType = collections.namedtuple("PushRuleEntry", fields)
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
new file mode 100644
index 0000000000..f253c9e2c3
--- /dev/null
+++ b/synapse/storage/pusher.py
@@ -0,0 +1,173 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+
+from ._base import SQLBaseStore, Table
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class PusherStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey):
+ sql = (
+ "SELECT id, user_name, kind, instance_handle, app_id,"
+ "app_display_name, device_display_name, pushkey, ts, data, "
+ "last_token, last_success, failing_since "
+ "FROM pushers "
+ "WHERE app_id = ? AND pushkey = ?"
+ )
+
+ rows = yield self._execute(
+ None, sql, app_id_and_pushkey[0], app_id_and_pushkey[1]
+ )
+
+ ret = [
+ {
+ "id": r[0],
+ "user_name": r[1],
+ "kind": r[2],
+ "instance_handle": r[3],
+ "app_id": r[4],
+ "app_display_name": r[5],
+ "device_display_name": r[6],
+ "pushkey": r[7],
+ "pushkey_ts": r[8],
+ "data": r[9],
+ "last_token": r[10],
+ "last_success": r[11],
+ "failing_since": r[12]
+ }
+ for r in rows
+ ]
+
+ defer.returnValue(ret[0])
+
+ @defer.inlineCallbacks
+ def get_all_pushers(self):
+ sql = (
+ "SELECT id, user_name, kind, instance_handle, app_id,"
+ "app_display_name, device_display_name, pushkey, ts, data, "
+ "last_token, last_success, failing_since "
+ "FROM pushers"
+ )
+
+ rows = yield self._execute(None, sql)
+
+ ret = [
+ {
+ "id": r[0],
+ "user_name": r[1],
+ "kind": r[2],
+ "instance_handle": r[3],
+ "app_id": r[4],
+ "app_display_name": r[5],
+ "device_display_name": r[6],
+ "pushkey": r[7],
+ "pushkey_ts": r[8],
+ "data": r[9],
+ "last_token": r[10],
+ "last_success": r[11],
+ "failing_since": r[12]
+ }
+ for r in rows
+ ]
+
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def add_pusher(self, user_name, instance_handle, kind, app_id,
+ app_display_name, device_display_name,
+ pushkey, pushkey_ts, lang, data):
+ try:
+ yield self._simple_upsert(
+ PushersTable.table_name,
+ dict(
+ app_id=app_id,
+ pushkey=pushkey,
+ ),
+ dict(
+ user_name=user_name,
+ kind=kind,
+ instance_handle=instance_handle,
+ app_display_name=app_display_name,
+ device_display_name=device_display_name,
+ ts=pushkey_ts,
+ lang=lang,
+ data=data
+ ))
+ except Exception as e:
+ logger.error("create_pusher with failed: %s", e)
+ raise StoreError(500, "Problem creating pusher.")
+
+ @defer.inlineCallbacks
+ def delete_pusher_by_app_id_pushkey(self, app_id, pushkey):
+ yield self._simple_delete_one(
+ PushersTable.table_name,
+ dict(app_id=app_id, pushkey=pushkey)
+ )
+
+ @defer.inlineCallbacks
+ def update_pusher_last_token(self, user_name, pushkey, last_token):
+ yield self._simple_update_one(
+ PushersTable.table_name,
+ {'user_name': user_name, 'pushkey': pushkey},
+ {'last_token': last_token}
+ )
+
+ @defer.inlineCallbacks
+ def update_pusher_last_token_and_success(self, user_name, pushkey,
+ last_token, last_success):
+ yield self._simple_update_one(
+ PushersTable.table_name,
+ {'user_name': user_name, 'pushkey': pushkey},
+ {'last_token': last_token, 'last_success': last_success}
+ )
+
+ @defer.inlineCallbacks
+ def update_pusher_failing_since(self, user_name, pushkey, failing_since):
+ yield self._simple_update_one(
+ PushersTable.table_name,
+ {'user_name': user_name, 'pushkey': pushkey},
+ {'failing_since': failing_since}
+ )
+
+
+class PushersTable(Table):
+ table_name = "pushers"
+
+ fields = [
+ "id",
+ "user_name",
+ "kind",
+ "instance_handle",
+ "app_id",
+ "app_display_name",
+ "device_display_name",
+ "pushkey",
+ "pushkey_ts",
+ "data",
+ "last_token",
+ "last_success",
+ "failing_since"
+ ]
+
+ EntryType = collections.namedtuple("PusherEntry", fields)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 75dffa4db2..029b07cc66 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -122,7 +122,8 @@ class RegistrationStore(SQLBaseStore):
def _query_for_auth(self, txn, token):
sql = (
- "SELECT users.name, users.admin, access_tokens.device_id"
+ "SELECT users.name, users.admin,"
+ " access_tokens.device_id, access_tokens.id as token_id"
" FROM users"
" INNER JOIN access_tokens on users.id = access_tokens.user_id"
" WHERE token = ?"
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
new file mode 100644
index 0000000000..4e1a9a2783
--- /dev/null
+++ b/synapse/storage/rejections.py
@@ -0,0 +1,43 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import SQLBaseStore
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class RejectionsStore(SQLBaseStore):
+ def _store_rejections_txn(self, txn, event_id, reason):
+ self._simple_insert_txn(
+ txn,
+ table="rejections",
+ values={
+ "event_id": event_id,
+ "reason": reason,
+ "last_check": self._clock.time_msec(),
+ }
+ )
+
+ def get_rejection_reason(self, event_id):
+ return self._simple_select_one_onecol(
+ table="rejections",
+ retcol="reason",
+ keyvalues={
+ "event_id": event_id,
+ },
+ allow_none=True,
+ )
diff --git a/synapse/storage/schema/delta/v12.sql b/synapse/storage/schema/delta/v12.sql
new file mode 100644
index 0000000000..a6867cba62
--- /dev/null
+++ b/synapse/storage/schema/delta/v12.sql
@@ -0,0 +1,54 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS rejections(
+ event_id TEXT NOT NULL,
+ reason TEXT NOT NULL,
+ last_check TEXT NOT NULL,
+ CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
+);
+
+-- Push notification endpoints that users have configured
+CREATE TABLE IF NOT EXISTS pushers (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_name TEXT NOT NULL,
+ instance_handle varchar(32) NOT NULL,
+ kind varchar(8) NOT NULL,
+ app_id varchar(64) NOT NULL,
+ app_display_name varchar(64) NOT NULL,
+ device_display_name varchar(128) NOT NULL,
+ pushkey blob NOT NULL,
+ ts BIGINT NOT NULL,
+ lang varchar(8),
+ data blob,
+ last_token TEXT,
+ last_success BIGINT,
+ failing_since BIGINT,
+ FOREIGN KEY(user_name) REFERENCES users(name),
+ UNIQUE (app_id, pushkey)
+);
+
+CREATE TABLE IF NOT EXISTS push_rules (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_name TEXT NOT NULL,
+ rule_id TEXT NOT NULL,
+ priority_class TINYINT NOT NULL,
+ priority INTEGER NOT NULL DEFAULT 0,
+ conditions TEXT NOT NULL,
+ actions TEXT NOT NULL,
+ UNIQUE(user_name, rule_id)
+);
+
+CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);
diff --git a/synapse/storage/schema/delta/v13.sql b/synapse/storage/schema/delta/v13.sql
new file mode 100644
index 0000000000..beb39ca201
--- /dev/null
+++ b/synapse/storage/schema/delta/v13.sql
@@ -0,0 +1,24 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS user_filters(
+ user_id TEXT,
+ filter_id INTEGER,
+ filter_json TEXT,
+ FOREIGN KEY(user_id) REFERENCES users(id)
+);
+
+CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
+ user_id, filter_id
+);
diff --git a/synapse/storage/schema/filtering.sql b/synapse/storage/schema/filtering.sql
new file mode 100644
index 0000000000..beb39ca201
--- /dev/null
+++ b/synapse/storage/schema/filtering.sql
@@ -0,0 +1,24 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE IF NOT EXISTS user_filters(
+ user_id TEXT,
+ filter_id INTEGER,
+ filter_json TEXT,
+ FOREIGN KEY(user_id) REFERENCES users(id)
+);
+
+CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters(
+ user_id, filter_id
+);
diff --git a/synapse/storage/schema/pusher.sql b/synapse/storage/schema/pusher.sql
new file mode 100644
index 0000000000..8c4dfd5c1b
--- /dev/null
+++ b/synapse/storage/schema/pusher.sql
@@ -0,0 +1,46 @@
+/* Copyright 2014 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- Push notification endpoints that users have configured
+CREATE TABLE IF NOT EXISTS pushers (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_name TEXT NOT NULL,
+ instance_handle varchar(32) NOT NULL,
+ kind varchar(8) NOT NULL,
+ app_id varchar(64) NOT NULL,
+ app_display_name varchar(64) NOT NULL,
+ device_display_name varchar(128) NOT NULL,
+ pushkey blob NOT NULL,
+ ts BIGINT NOT NULL,
+ lang varchar(8),
+ data blob,
+ last_token TEXT,
+ last_success BIGINT,
+ failing_since BIGINT,
+ FOREIGN KEY(user_name) REFERENCES users(name),
+ UNIQUE (app_id, pushkey)
+);
+
+CREATE TABLE IF NOT EXISTS push_rules (
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
+ user_name TEXT NOT NULL,
+ rule_id TEXT NOT NULL,
+ priority_class TINYINT NOT NULL,
+ priority INTEGER NOT NULL DEFAULT 0,
+ conditions TEXT NOT NULL,
+ actions TEXT NOT NULL,
+ UNIQUE(user_name, rule_id)
+);
+
+CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name);
diff --git a/synapse/storage/schema/rejections.sql b/synapse/storage/schema/rejections.sql
new file mode 100644
index 0000000000..bd2a8b1bb5
--- /dev/null
+++ b/synapse/storage/schema/rejections.sql
@@ -0,0 +1,21 @@
+/* Copyright 2015 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS rejections(
+ event_id TEXT NOT NULL,
+ reason TEXT NOT NULL,
+ last_check TEXT NOT NULL,
+ CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
+);
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 8ac2adab05..3ccb6f8a61 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -82,10 +82,10 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
def parse(cls, string):
try:
if string[0] == 's':
- return cls(None, int(string[1:]))
+ return cls(topological=None, stream=int(string[1:]))
if string[0] == 't':
parts = string[1:].split('-', 1)
- return cls(int(parts[1]), int(parts[0]))
+ return cls(topological=int(parts[0]), stream=int(parts[1]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@@ -94,7 +94,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")):
def parse_stream_token(cls, string):
try:
if string[0] == 's':
- return cls(None, int(string[1:]))
+ return cls(topological=None, stream=int(string[1:]))
except:
pass
raise SynapseError(400, "Invalid token %r" % (string,))
@@ -181,8 +181,11 @@ class StreamStore(SQLBaseStore):
get_prev_content=True
)
+ self._set_before_and_after(ret, rows)
+
if rows:
key = "s%d" % max([r["stream_ordering"] for r in rows])
+
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -260,22 +263,44 @@ class StreamStore(SQLBaseStore):
get_prev_content=True
)
+ self._set_before_and_after(events, rows)
+
return events, next_token,
return self.runInteraction("paginate_room_events", f)
def get_recent_events_for_room(self, room_id, limit, end_token,
- with_feedback=False):
+ with_feedback=False, from_token=None):
# TODO (erikj): Handle compressed feedback
- sql = (
- "SELECT stream_ordering, topological_ordering, event_id FROM events "
- "WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0 "
- "ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ? "
- )
+ end_token = _StreamToken.parse_stream_token(end_token)
- def f(txn):
- txn.execute(sql, (room_id, end_token, limit,))
+ if from_token is None:
+ sql = (
+ "SELECT stream_ordering, topological_ordering, event_id"
+ " FROM events"
+ " WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0"
+ " ORDER BY topological_ordering DESC, stream_ordering DESC"
+ " LIMIT ?"
+ )
+ else:
+ from_token = _StreamToken.parse_stream_token(from_token)
+ sql = (
+ "SELECT stream_ordering, topological_ordering, event_id"
+ " FROM events"
+ " WHERE room_id = ? AND stream_ordering > ?"
+ " AND stream_ordering <= ? AND outlier = 0"
+ " ORDER BY topological_ordering DESC, stream_ordering DESC"
+ " LIMIT ?"
+ )
+
+ def get_recent_events_for_room_txn(txn):
+ if from_token is None:
+ txn.execute(sql, (room_id, end_token.stream, limit,))
+ else:
+ txn.execute(sql, (
+ room_id, from_token.stream, end_token.stream, limit
+ ))
rows = self.cursor_to_dict(txn)
@@ -291,9 +316,9 @@ class StreamStore(SQLBaseStore):
toke = rows[0]["stream_ordering"] - 1
start_token = str(_StreamToken(topo, toke))
- token = (start_token, end_token)
+ token = (start_token, str(end_token))
else:
- token = (end_token, end_token)
+ token = (str(end_token), str(end_token))
events = self._get_events_txn(
txn,
@@ -301,9 +326,13 @@ class StreamStore(SQLBaseStore):
get_prev_content=True
)
+ self._set_before_and_after(events, rows)
+
return events, token
- return self.runInteraction("get_recent_events_for_room", f)
+ return self.runInteraction(
+ "get_recent_events_for_room", get_recent_events_for_room_txn
+ )
def get_room_events_max_id(self):
return self.runInteraction(
@@ -325,3 +354,12 @@ class StreamStore(SQLBaseStore):
key = res[0]["m"]
return "s%d" % (key,)
+
+ @staticmethod
+ def _set_before_and_after(events, rows):
+ for event, row in zip(events, rows):
+ stream = row["stream_ordering"]
+ topo = event.depth
+ internal = event.internal_metadata
+ internal.before = str(_StreamToken(topo, stream - 1))
+ internal.after = str(_StreamToken(topo, stream))
diff --git a/synapse/types.py b/synapse/types.py
index faac729ff2..f6a1b0bbcf 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -119,3 +119,6 @@ class StreamToken(
d = self._asdict()
d[key] = new_value
return StreamToken(**d)
+
+
+ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id"))
|