diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 59db76debc..0db26fcfd7 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -675,27 +675,18 @@ class Auth(object):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
- user_prefix = "user_id = "
- user = None
- user_id = None
- guest = False
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(user_prefix):
- user_id = caveat.caveat_id[len(user_prefix):]
- user = UserID.from_string(user_id)
- elif caveat.caveat_id == "guest = true":
- guest = True
+ user_id = self.get_user_id_from_macaroon(macaroon)
+ user = UserID.from_string(user_id)
self.validate_macaroon(
macaroon, rights, self.hs.config.expire_access_token,
user_id=user_id,
)
- if user is None:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
- errcode=Codes.UNKNOWN_TOKEN
- )
+ guest = False
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id == "guest = true":
+ guest = True
if guest:
ret = {
@@ -743,6 +734,29 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN
)
+ def get_user_id_from_macaroon(self, macaroon):
+ """Retrieve the user_id given by the caveats on the macaroon.
+
+ Does *not* validate the macaroon.
+
+ Args:
+ macaroon (pymacaroons.Macaroon): The macaroon to extract the user_id from
+
+ Returns:
+ (str) user id
+
+ Raises:
+ AuthError if there is no user_id caveat in the macaroon
+ """
+ user_prefix = "user_id = "
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith(user_prefix):
+ return caveat.caveat_id[len(user_prefix):]
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@@ -754,6 +768,7 @@ class Auth(object):
verify_expiry(bool): Whether to verify whether the macaroon has expired.
This should really always be True, but no clients currently implement
token refresh, so we can't enforce expiry yet.
+ user_id (str): The user_id that the macaroon's caveats must assert
"""
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
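A minimal sketch (not part of the patch) of what the new Auth.get_user_id_from_macaroon helper does, assuming a macaroon minted the way _generate_base_macaroon mints them: it only scans the first-party caveats for the "user_id = " prefix and performs no signature or expiry checks.

    import pymacaroons

    macaroon = pymacaroons.Macaroon(
        location="example.com", identifier="key", key="secret-key",
    )
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("user_id = @alice:example.com")

    user_prefix = "user_id = "
    user_id = next(
        caveat.caveat_id[len(user_prefix):]
        for caveat in macaroon.caveats
        if caveat.caveat_id.startswith(user_prefix)
    )
    assert user_id == "@alice:example.com"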
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index c8dde0fcb8..8d755a4b33 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -80,11 +80,6 @@ class PusherSlaveStore(
DataStore.get_profile_displayname.__func__
)
- # XXX: This is a bit broken because we don't persist forgotten rooms
- # in a way that they can be streamed. This means that we don't have a
- # way to invalidate the forgotten rooms cache correctly.
- # For now we expire the cache every 10 minutes.
- BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
@@ -168,7 +163,6 @@ class PusherServer(HomeServer):
store = self.get_datastore()
replication_url = self.config.worker_replication_url
pusher_pool = self.get_pusherpool()
- clock = self.get_clock()
def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
@@ -220,21 +214,11 @@ class PusherServer(HomeServer):
min_stream_id, max_stream_id, affected_room_ids
)
- def expire_broken_caches():
- store.who_forgot_in_room.invalidate_all()
-
- next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
- now_ms = clock.time_msec()
- if now_ms > next_expire_broken_caches_ms:
- expire_broken_caches()
- next_expire_broken_caches_ms = (
- now_ms + store.BROKEN_CACHE_EXPIRY_MS
- )
yield store.process_replication(result)
poke_pushers(result)
except:
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 215ccfd522..e3173533e2 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -26,6 +26,7 @@ from synapse.http.site import SynapseSite
from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
+from synapse.rest.client.v1 import events
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@@ -74,11 +75,6 @@ class SynchrotronSlavedStore(
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
- # XXX: This is a bit broken because we don't persist forgotten rooms
- # in a way that they can be streamed. This means that we don't have a
- # way to invalidate the forgotten rooms cache correctly.
- # For now we expire the cache every 10 minutes.
- BROKEN_CACHE_EXPIRY_MS = 60 * 60 * 1000
who_forgot_in_room = (
RoomMemberStore.__dict__["who_forgot_in_room"]
)
@@ -89,17 +85,23 @@ class SynchrotronSlavedStore(
get_presence_list_accepted = PresenceStore.__dict__[
"get_presence_list_accepted"
]
+ get_presence_list_observers_accepted = PresenceStore.__dict__[
+ "get_presence_list_observers_accepted"
+ ]
+
UPDATE_SYNCING_USERS_MS = 10 * 1000
class SynchrotronPresence(object):
def __init__(self, hs):
+ self.is_mine_id = hs.is_mine_id
self.http_client = hs.get_simple_http_client()
self.store = hs.get_datastore()
self.user_to_num_current_syncs = {}
self.syncing_users_url = hs.config.worker_replication_url + "/syncing_users"
self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info()
self.user_to_current_state = {
@@ -119,11 +121,13 @@ class SynchrotronPresence(object):
reactor.addSystemEventTrigger("before", "shutdown", self._on_shutdown)
- def set_state(self, user, state):
+ def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work?
pass
get_states = PresenceHandler.get_states.__func__
+ get_state = PresenceHandler.get_state.__func__
+ _get_interested_parties = PresenceHandler._get_interested_parties.__func__
current_state_for_users = PresenceHandler.current_state_for_users.__func__
@defer.inlineCallbacks
@@ -194,19 +198,39 @@ class SynchrotronPresence(object):
self._need_to_send_sync = False
yield self._send_syncing_users_now()
+ @defer.inlineCallbacks
+ def notify_from_replication(self, states, stream_id):
+ parties = yield self._get_interested_parties(
+ states, calculate_remote_hosts=False
+ )
+ room_ids_to_states, users_to_states, _ = parties
+
+ self.notifier.on_new_event(
+ "presence_key", stream_id, rooms=room_ids_to_states.keys(),
+ users=users_to_states.keys()
+ )
+
+ @defer.inlineCallbacks
def process_replication(self, result):
stream = result.get("presence", {"rows": []})
+ states = []
for row in stream["rows"]:
(
position, user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
) = row
- self.user_to_current_state[user_id] = UserPresenceState(
+ state = UserPresenceState(
user_id, state, last_active_ts,
last_federation_update_ts, last_user_sync_ts, status_msg,
currently_active
)
+ self.user_to_current_state[user_id] = state
+ states.append(state)
+
+ if states and "position" in stream:
+ stream_id = int(stream["position"])
+ yield self.notify_from_replication(states, stream_id)
class SynchrotronTyping(object):
@@ -266,10 +290,12 @@ class SynchrotronServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
+ events.register_servlets(self, resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
"/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
})
root_resource = create_resource_tree(resources, Resource())
@@ -307,15 +333,10 @@ class SynchrotronServer(HomeServer):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.worker_replication_url
- clock = self.get_clock()
notifier = self.get_notifier()
presence_handler = self.get_presence_handler()
typing_handler = self.get_typing_handler()
- def expire_broken_caches():
- store.who_forgot_in_room.invalidate_all()
- store.get_presence_list_accepted.invalidate_all()
-
def notify_from_stream(
result, stream_name, stream_key, room=None, user=None
):
@@ -377,22 +398,15 @@ class SynchrotronServer(HomeServer):
result, "typing", "typing_key", room="room_id"
)
- next_expire_broken_caches_ms = 0
while True:
try:
args = store.stream_positions()
args.update(typing_handler.stream_positions())
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
- now_ms = clock.time_msec()
- if now_ms > next_expire_broken_caches_ms:
- expire_broken_caches()
- next_expire_broken_caches_ms = (
- now_ms + store.BROKEN_CACHE_EXPIRY_MS
- )
yield store.process_replication(result)
typing_handler.process_replication(result)
- presence_handler.process_replication(result)
+ yield presence_handler.process_replication(result)
notify(result)
except:
logger.exception("Error replicating from %r", replication_url)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5012c10ee8..7cd11cfae7 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -61,6 +61,10 @@ Attributes:
"""
+class KeyLookupError(ValueError):
+ pass
+
+
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -363,7 +367,7 @@ class Keyring(object):
)
except Exception as e:
logger.info(
- "Unable to getting key %r for %r directly: %s %s",
+ "Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
@@ -425,7 +429,7 @@ class Keyring(object):
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
- raise ValueError(
+ raise KeyLookupError(
"Key response not signed by perspective server"
" %r" % (perspective_name,)
)
@@ -448,7 +452,7 @@ class Keyring(object):
list(response[u"signatures"][perspective_name]),
list(perspective_keys)
)
- raise ValueError(
+ raise KeyLookupError(
"Response not signed with a known key for perspective"
" server %r" % (perspective_name,)
)
@@ -491,10 +495,10 @@ class Keyring(object):
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
- raise ValueError("Key response not signed by remote server")
+ raise KeyLookupError("Key response not signed by remote server")
if "tls_fingerprints" not in response:
- raise ValueError("Key response missing TLS fingerprints")
+ raise KeyLookupError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
@@ -508,7 +512,7 @@ class Keyring(object):
response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
- raise ValueError("TLS certificate not allowed by fingerprints")
+ raise KeyLookupError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
from_server=server_name,
@@ -560,14 +564,14 @@ class Keyring(object):
server_name = response_json["server_name"]
if only_from_server:
if server_name != from_server:
- raise ValueError(
+ raise KeyLookupError(
"Expected a response for server %r not %r" % (
from_server, server_name
)
)
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
- raise ValueError(
+ raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
@@ -635,15 +639,15 @@ class Keyring(object):
if ("signatures" not in response
or server_name not in response["signatures"]):
- raise ValueError("Key response not signed by remote server")
+ raise KeyLookupError("Key response not signed by remote server")
if "tls_certificate" not in response:
- raise ValueError("Key response missing TLS certificate")
+ raise KeyLookupError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
- raise ValueError("TLS certificate doesn't match")
+ raise KeyLookupError("TLS certificate doesn't match")
# Cache the result in the datastore.
@@ -659,7 +663,7 @@ class Keyring(object):
for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]:
- raise ValueError(
+ raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index aab18d7f71..0e9fd902af 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -88,6 +88,8 @@ def prune_event(event):
if "age_ts" in event.unsigned:
allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
+ if "replaces_state" in event.unsigned:
+ allowed_fields["unsigned"]["replaces_state"] = event.unsigned["replaces_state"]
return type(event)(
allowed_fields,
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index da95c2ad6d..9ba3151713 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -51,10 +51,34 @@ sent_edus_counter = metrics.register_counter("sent_edus")
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
+PDU_RETRY_TIME_MS = 1 * 60 * 1000
+
+
class FederationClient(FederationBase):
def __init__(self, hs):
super(FederationClient, self).__init__(hs)
+ self.pdu_destination_tried = {}
+ self._clock.looping_call(
+ self._clear_tried_cache, 60 * 1000,
+ )
+
+ def _clear_tried_cache(self):
+ """Clear pdu_destination_tried cache"""
+ now = self._clock.time_msec()
+
+ old_dict = self.pdu_destination_tried
+ self.pdu_destination_tried = {}
+
+ for event_id, destination_dict in old_dict.items():
+ destination_dict = {
+ dest: time
+ for dest, time in destination_dict.items()
+ if time + PDU_RETRY_TIME_MS > now
+ }
+ if destination_dict:
+ self.pdu_destination_tried[event_id] = destination_dict
+
def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
@@ -240,8 +264,15 @@ class FederationClient(FederationBase):
if ev:
defer.returnValue(ev)
+ pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+
pdu = None
for destination in destinations:
+ now = self._clock.time_msec()
+ last_attempt = pdu_attempts.get(destination, 0)
+ if last_attempt + PDU_RETRY_TIME_MS > now:
+ continue
+
try:
limiter = yield get_retry_limiter(
destination,
@@ -269,25 +300,19 @@ class FederationClient(FederationBase):
break
- except SynapseError as e:
- logger.info(
- "Failed to get PDU %s from %s because %s",
- event_id, destination, e,
- )
- continue
- except CodeMessageException as e:
- if 400 <= e.code < 500:
- raise
+ pdu_attempts[destination] = now
+ except SynapseError as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
- continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
+ pdu_attempts[destination] = now
+
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
@@ -406,7 +431,7 @@ class FederationClient(FederationBase):
events and the second is a list of event ids that we failed to fetch.
"""
if return_local:
- seen_events = yield self.store.get_events(event_ids)
+ seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
signed_events = seen_events.values()
else:
seen_events = yield self.store.have_events(event_ids)
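A minimal sketch (not part of the patch) of the new per-destination back-off in get_pdu: a destination that failed within the last PDU_RETRY_TIME_MS is skipped, and _clear_tried_cache drops entries once the window has passed. The should_retry helper below is illustrative only.

    PDU_RETRY_TIME_MS = 1 * 60 * 1000  # one minute, as in the patch

    def should_retry(pdu_attempts, destination, now_ms):
        # pdu_attempts maps destination -> ms timestamp of the last failure
        last_attempt = pdu_attempts.get(destination, 0)
        return last_attempt + PDU_RETRY_TIME_MS <= now_ms

    attempts = {"remote.example.com": 1000000}
    assert not should_retry(attempts, "remote.example.com", 1030000)  # 30s later
    assert should_retry(attempts, "remote.example.com", 1070000)      # 70s later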
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 5787f854d4..cb2ef0210c 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -21,11 +21,11 @@ from .units import Transaction
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
-from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
+from synapse.util.metrics import measure_func
import synapse.metrics
import logging
@@ -51,7 +51,7 @@ class TransactionQueue(object):
self.transport_layer = transport_layer
- self._clock = hs.get_clock()
+ self.clock = hs.get_clock()
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
@@ -82,7 +82,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {}
# HACK to get unique tx id
- self._next_txn_id = int(self._clock.time_msec())
+ self._next_txn_id = int(self.clock.time_msec())
def can_send_to(self, destination):
"""Can we send messages to the given server?
@@ -119,266 +119,215 @@ class TransactionQueue(object):
if not destinations:
return
- deferreds = []
-
for destination in destinations:
- deferred = defer.Deferred()
self.pending_pdus_by_dest.setdefault(destination, []).append(
- (pdu, deferred, order)
+ (pdu, order)
)
- def chain(failure):
- if not deferred.called:
- deferred.errback(failure)
-
- def log_failure(f):
- logger.warn("Failed to send pdu to %s: %s", destination, f.value)
-
- deferred.addErrback(log_failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(chain)
-
- deferreds.append(deferred)
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
- # NO inlineCallbacks
def enqueue_edu(self, edu):
destination = edu.destination
if not self.can_send_to(destination):
return
- deferred = defer.Deferred()
- self.pending_edus_by_dest.setdefault(destination, []).append(
- (edu, deferred)
- )
+ self.pending_edus_by_dest.setdefault(destination, []).append(edu)
- def chain(failure):
- if not deferred.called:
- deferred.errback(failure)
-
- def log_failure(f):
- logger.warn("Failed to send edu to %s: %s", destination, f.value)
-
- deferred.addErrback(log_failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(chain)
-
- return deferred
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
- @defer.inlineCallbacks
def enqueue_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
- deferred = defer.Deferred()
-
if not self.can_send_to(destination):
return
self.pending_failures_by_dest.setdefault(
destination, []
- ).append(
- (failure, deferred)
- )
-
- def chain(f):
- if not deferred.called:
- deferred.errback(f)
-
- def log_failure(f):
- logger.warn("Failed to send failure to %s: %s", destination, f.value)
-
- deferred.addErrback(log_failure)
-
- with PreserveLoggingContext():
- self._attempt_new_transaction(destination).addErrback(chain)
+ ).append(failure)
- yield deferred
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
@defer.inlineCallbacks
- @log_function
def _attempt_new_transaction(self, destination):
yield run_on_reactor()
+ while True:
+ # list of (pending_pdu, deferred, order)
+ if destination in self.pending_transactions:
+ # XXX: pending_transactions can get stuck on by a never-ending
+ # request at which point pending_pdus_by_dest just keeps growing.
+ # we need application-layer timeouts of some flavour for these
+ # requests
+ logger.debug(
+ "TX [%s] Transaction already in progress",
+ destination
+ )
+ return
- # list of (pending_pdu, deferred, order)
- if destination in self.pending_transactions:
- # XXX: pending_transactions can get stuck on by a never-ending
- # request at which point pending_pdus_by_dest just keeps growing.
- # we need application-layer timeouts of some flavour of these
- # requests
- logger.debug(
- "TX [%s] Transaction already in progress",
- destination
- )
- return
-
- pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
- pending_edus = self.pending_edus_by_dest.pop(destination, [])
- pending_failures = self.pending_failures_by_dest.pop(destination, [])
+ pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+ pending_edus = self.pending_edus_by_dest.pop(destination, [])
+ pending_failures = self.pending_failures_by_dest.pop(destination, [])
- if pending_pdus:
- logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
- destination, len(pending_pdus))
+ if pending_pdus:
+ logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+ destination, len(pending_pdus))
- if not pending_pdus and not pending_edus and not pending_failures:
- logger.debug("TX [%s] Nothing to send", destination)
- return
+ if not pending_pdus and not pending_edus and not pending_failures:
+ logger.debug("TX [%s] Nothing to send", destination)
+ return
- try:
- self.pending_transactions[destination] = 1
+ yield self._send_new_transaction(
+ destination, pending_pdus, pending_edus, pending_failures
+ )
- logger.debug("TX [%s] _attempt_new_transaction", destination)
+ @measure_func("_send_new_transaction")
+ @defer.inlineCallbacks
+ def _send_new_transaction(self, destination, pending_pdus, pending_edus,
+ pending_failures):
# Sort based on the order field
- pending_pdus.sort(key=lambda t: t[2])
-
+ pending_pdus.sort(key=lambda t: t[1])
pdus = [x[0] for x in pending_pdus]
- edus = [x[0] for x in pending_edus]
- failures = [x[0].get_dict() for x in pending_failures]
- deferreds = [
- x[1]
- for x in pending_pdus + pending_edus + pending_failures
- ]
-
- txn_id = str(self._next_txn_id)
-
- limiter = yield get_retry_limiter(
- destination,
- self._clock,
- self.store,
- )
+ edus = pending_edus
+ failures = [x.get_dict() for x in pending_failures]
- logger.debug(
- "TX [%s] {%s} Attempting new transaction"
- " (pdus: %d, edus: %d, failures: %d)",
- destination, txn_id,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures)
- )
+ try:
+ self.pending_transactions[destination] = 1
- logger.debug("TX [%s] Persisting transaction...", destination)
+ logger.debug("TX [%s] _attempt_new_transaction", destination)
- transaction = Transaction.create_new(
- origin_server_ts=int(self._clock.time_msec()),
- transaction_id=txn_id,
- origin=self.server_name,
- destination=destination,
- pdus=pdus,
- edus=edus,
- pdu_failures=failures,
- )
+ txn_id = str(self._next_txn_id)
- self._next_txn_id += 1
+ limiter = yield get_retry_limiter(
+ destination,
+ self.clock,
+ self.store,
+ )
- yield self.transaction_actions.prepare_to_send(transaction)
+ logger.debug(
+ "TX [%s] {%s} Attempting new transaction"
+ " (pdus: %d, edus: %d, failures: %d)",
+ destination, txn_id,
+ len(pending_pdus),
+ len(pending_edus),
+ len(pending_failures)
+ )
- logger.debug("TX [%s] Persisted transaction", destination)
- logger.info(
- "TX [%s] {%s} Sending transaction [%s],"
- " (PDUs: %d, EDUs: %d, failures: %d)",
- destination, txn_id,
- transaction.transaction_id,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures),
- )
+ logger.debug("TX [%s] Persisting transaction...", destination)
- with limiter:
- # Actually send the transaction
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def json_data_cb():
- data = transaction.get_dict()
- now = int(self._clock.time_msec())
- if "pdus" in data:
- for p in data["pdus"]:
- if "age_ts" in p:
- unsigned = p.setdefault("unsigned", {})
- unsigned["age"] = now - int(p["age_ts"])
- del p["age_ts"]
- return data
-
- try:
- response = yield self.transport_layer.send_transaction(
- transaction, json_data_cb
- )
- code = 200
-
- if response:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
- logger.warn(
- "Transaction returned error for %s: %s",
- e_id, r,
- )
- except HttpResponseException as e:
- code = e.code
- response = e.response
+ transaction = Transaction.create_new(
+ origin_server_ts=int(self.clock.time_msec()),
+ transaction_id=txn_id,
+ origin=self.server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ pdu_failures=failures,
+ )
+
+ self._next_txn_id += 1
+
+ yield self.transaction_actions.prepare_to_send(transaction)
+ logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
- "TX [%s] {%s} got %d response",
- destination, txn_id, code
+ "TX [%s] {%s} Sending transaction [%s],"
+ " (PDUs: %d, EDUs: %d, failures: %d)",
+ destination, txn_id,
+ transaction.transaction_id,
+ len(pending_pdus),
+ len(pending_edus),
+ len(pending_failures),
)
- logger.debug("TX [%s] Sent transaction", destination)
- logger.debug("TX [%s] Marking as delivered...", destination)
+ with limiter:
+ # Actually send the transaction
+
+ # FIXME (erikj): This is a bit of a hack to make the Pdu age
+ # keys work
+ def json_data_cb():
+ data = transaction.get_dict()
+ now = int(self.clock.time_msec())
+ if "pdus" in data:
+ for p in data["pdus"]:
+ if "age_ts" in p:
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
+ return data
+
+ try:
+ response = yield self.transport_layer.send_transaction(
+ transaction, json_data_cb
+ )
+ code = 200
+
+ if response:
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
+ logger.warn(
+ "Transaction returned error for %s: %s",
+ e_id, r,
+ )
+ except HttpResponseException as e:
+ code = e.code
+ response = e.response
+
+ logger.info(
+ "TX [%s] {%s} got %d response",
+ destination, txn_id, code
+ )
- yield self.transaction_actions.delivered(
- transaction, code, response
- )
+ logger.debug("TX [%s] Sent transaction", destination)
+ logger.debug("TX [%s] Marking as delivered...", destination)
- logger.debug("TX [%s] Marked as delivered", destination)
-
- logger.debug("TX [%s] Yielding to callbacks...", destination)
-
- for deferred in deferreds:
- if code == 200:
- deferred.callback(None)
- else:
- deferred.errback(RuntimeError("Got status %d" % code))
-
- # Ensures we don't continue until all callbacks on that
- # deferred have fired
- try:
- yield deferred
- except:
- pass
-
- logger.debug("TX [%s] Yielded to callbacks", destination)
- except NotRetryingDestination:
- logger.info(
- "TX [%s] not ready for retry yet - "
- "dropping transaction for now",
- destination,
- )
- except RuntimeError as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
- except Exception as e:
- # We capture this here as there as nothing actually listens
- # for this finishing functions deferred.
- logger.warn(
- "TX [%s] Problem in _attempt_transaction: %s",
- destination,
- e,
- )
+ yield self.transaction_actions.delivered(
+ transaction, code, response
+ )
- for deferred in deferreds:
- if not deferred.called:
- deferred.errback(e)
+ logger.debug("TX [%s] Marked as delivered", destination)
+
+ if code != 200:
+ for p in pdus:
+ logger.info(
+ "Failed to send event %s to %s", p.event_id, destination
+ )
+ except NotRetryingDestination:
+ logger.info(
+ "TX [%s] not ready for retry yet - "
+ "dropping transaction for now",
+ destination,
+ )
+ except RuntimeError as e:
+ # We capture this here as nothing actually listens for the
+ # deferred this function returns.
+ logger.warn(
+ "TX [%s] Problem in _attempt_transaction: %s",
+ destination,
+ e,
+ )
+
+ for p in pdus:
+ logger.info("Failed to send event %s to %s", p.event_id, destination)
+ except Exception as e:
+ # We capture this here as nothing actually listens for the
+ # deferred this function returns.
+ logger.warn(
+ "TX [%s] Problem in _attempt_transaction: %s",
+ destination,
+ e,
+ )
- finally:
- # We want to be *very* sure we delete this after we stop processing
- self.pending_transactions.pop(destination, None)
+ for p in pdus:
+ logger.info("Failed to send event %s to %s", p.event_id, destination)
- # Check to see if there is anything else to send.
- self._attempt_new_transaction(destination)
+ finally:
+ # We want to be *very* sure we delete this after we stop processing
+ self.pending_transactions.pop(destination, None)
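Roughly the shape the refactored _attempt_new_transaction now takes, as a minimal sketch (not part of the patch): the per-PDU deferreds are gone, callers fire and forget via preserve_context_over_fn, and a single loop per destination keeps draining the queues until nothing is left to send. All names below are illustrative.

    pending_by_dest = {}   # destination -> list of queued items
    in_progress = set()    # destinations with a transaction in flight

    def attempt_new_transaction(destination, send_batch):
        if destination in in_progress:
            # The loop already running for this destination will pick up
            # anything that was queued in the meantime.
            return
        in_progress.add(destination)
        try:
            while True:
                batch = pending_by_dest.pop(destination, [])
                if not batch:
                    return
                send_batch(destination, batch)
        finally:
            in_progress.discard(destination)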
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 1a50a2ec98..63d05f2531 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -19,7 +19,6 @@ from .room import (
)
from .room_member import RoomMemberHandler
from .message import MessageHandler
-from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler
from .profile import ProfileHandler
from .directory import DirectoryHandler
@@ -53,8 +52,6 @@ class Handlers(object):
self.message_handler = MessageHandler(hs)
self.room_creation_handler = RoomCreationHandler(hs)
self.room_member_handler = RoomMemberHandler(hs)
- self.event_stream_handler = EventStreamHandler(hs)
- self.event_handler = EventHandler(hs)
self.federation_handler = FederationHandler(hs)
self.profile_handler = ProfileHandler(hs)
self.directory_handler = DirectoryHandler(hs)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 2e138f328f..a582d6334b 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -70,11 +70,11 @@ class AuthHandler(BaseHandler):
self.ldap_uri = hs.config.ldap_uri
self.ldap_start_tls = hs.config.ldap_start_tls
self.ldap_base = hs.config.ldap_base
- self.ldap_filter = hs.config.ldap_filter
self.ldap_attributes = hs.config.ldap_attributes
if self.ldap_mode == LDAPMode.SEARCH:
self.ldap_bind_dn = hs.config.ldap_bind_dn
self.ldap_bind_password = hs.config.ldap_bind_password
+ self.ldap_filter = hs.config.ldap_filter
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.device_handler = hs.get_device_handler()
@@ -660,7 +660,7 @@ class AuthHandler(BaseHandler):
else:
logger.warn(
"ldap registration failed: unexpected (%d!=1) amount of results",
- len(result)
+ len(conn.response)
)
defer.returnValue(False)
@@ -719,13 +719,14 @@ class AuthHandler(BaseHandler):
return macaroon.serialize()
def validate_short_term_login_token_and_get_user_id(self, login_token):
+ auth_api = self.hs.get_auth()
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
- auth_api = self.hs.get_auth()
- auth_api.validate_macaroon(macaroon, "login", True)
- return self.get_user_from_macaroon(macaroon)
- except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
- raise AuthError(401, "Invalid token", errcode=Codes.UNKNOWN_TOKEN)
+ user_id = auth_api.get_user_id_from_macaroon(macaroon)
+ auth_api.validate_macaroon(macaroon, "login", True, user_id)
+ return user_id
+ except Exception:
+ raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
@@ -736,16 +737,6 @@ class AuthHandler(BaseHandler):
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
- def get_user_from_macaroon(self, macaroon):
- user_prefix = "user_id = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(user_prefix):
- return caveat.caveat_id[len(user_prefix):]
- raise AuthError(
- self.INVALID_TOKEN_HTTP_STATUS, "No user_id found in token",
- errcode=Codes.UNKNOWN_TOKEN
- )
-
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
password_hash = self.hash(newpassword)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 618cb53629..ff6bb475b5 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -249,7 +249,7 @@ class FederationHandler(BaseHandler):
if ev.type != EventTypes.Member:
continue
try:
- domain = UserID.from_string(ev.state_key).domain
+ domain = get_domain_from_id(ev.state_key)
except:
continue
@@ -1093,16 +1093,17 @@ class FederationHandler(BaseHandler):
)
if event:
- # FIXME: This is a temporary work around where we occasionally
- # return events slightly differently than when they were
- # originally signed
- event.signatures.update(
- compute_event_signature(
- event,
- self.hs.hostname,
- self.hs.config.signing_key[0]
+ if self.hs.is_mine_id(event.event_id):
+ # FIXME: This is a temporary work around where we occasionally
+ # return events slightly differently than when they were
+ # originally signed
+ event.signatures.update(
+ compute_event_signature(
+ event,
+ self.hs.hostname,
+ self.hs.config.signing_key[0]
+ )
)
- )
if do_auth:
in_room = yield self.auth.check_host_in_room(
@@ -1112,6 +1113,12 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
+ events = yield self._filter_events_for_server(
+ origin, event.room_id, [event]
+ )
+
+ event = events[0]
+
defer.returnValue(event)
else:
defer.returnValue(None)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6b70fa3817..6a1fe76c88 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -503,7 +503,7 @@ class PresenceHandler(object):
defer.returnValue(states)
@defer.inlineCallbacks
- def _get_interested_parties(self, states):
+ def _get_interested_parties(self, states, calculate_remote_hosts=True):
"""Given a list of states return which entities (rooms, users, servers)
are interested in the given states.
@@ -526,14 +526,15 @@ class PresenceHandler(object):
users_to_states.setdefault(state.user_id, []).append(state)
hosts_to_states = {}
- for room_id, states in room_ids_to_states.items():
- local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
- if not local_states:
- continue
+ if calculate_remote_hosts:
+ for room_id, states in room_ids_to_states.items():
+ local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
+ if not local_states:
+ continue
- hosts = yield self.store.get_joined_hosts_for_room(room_id)
- for host in hosts:
- hosts_to_states.setdefault(host, []).extend(local_states)
+ hosts = yield self.store.get_joined_hosts_for_room(room_id)
+ for host in hosts:
+ hosts_to_states.setdefault(host, []).extend(local_states)
for user_id, states in users_to_states.items():
local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
@@ -565,6 +566,16 @@ class PresenceHandler(object):
self._push_to_remotes(hosts_to_states)
+ @defer.inlineCallbacks
+ def notify_for_states(self, state, stream_id):
+ parties = yield self._get_interested_parties([state])
+ room_ids_to_states, users_to_states, hosts_to_states = parties
+
+ self.notifier.on_new_event(
+ "presence_key", stream_id, rooms=room_ids_to_states.keys(),
+ users=[UserID.from_string(u) for u in users_to_states.keys()]
+ )
+
def _push_to_remotes(self, hosts_to_states):
"""Sends state updates to remote servers.
@@ -672,7 +683,7 @@ class PresenceHandler(object):
])
@defer.inlineCallbacks
- def set_state(self, target_user, state):
+ def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
status_msg = state.get("status_msg", None)
@@ -689,10 +700,13 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id)
new_fields = {
- "state": presence,
- "status_msg": status_msg if presence != PresenceState.OFFLINE else None
+ "state": presence
}
+ if not ignore_status_msg:
+ msg = status_msg if presence != PresenceState.OFFLINE else None
+ new_fields["status_msg"] = msg
+
if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec()
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8cec8fc4ed..4709112a0c 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -141,7 +141,7 @@ class RoomMemberHandler(BaseHandler):
third_party_signed=None,
ratelimit=True,
):
- key = (target, room_id,)
+ key = (room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c3589534f8..f93093dd85 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -155,9 +155,7 @@ class MatrixFederationHttpClient(object):
time_out=timeout / 1000. if timeout else 60,
)
- response = yield preserve_context_over_fn(
- send_request,
- )
+ response = yield preserve_context_over_fn(send_request)
log_result = "%d %s" % (response.code, response.phrase,)
break
diff --git a/synapse/replication/resource.py b/synapse/replication/resource.py
index 8c2d487ff4..84993b33b3 100644
--- a/synapse/replication/resource.py
+++ b/synapse/replication/resource.py
@@ -41,6 +41,7 @@ STREAM_NAMES = (
("push_rules",),
("pushers",),
("state",),
+ ("caches",),
)
@@ -70,6 +71,7 @@ class ReplicationResource(Resource):
* "backfill": Old events that have been backfilled from other servers.
* "push_rules": Per user changes to push rules.
* "pushers": Per user changes to their pushers.
+ * "caches": Cache invalidations.
The API takes two additional query parameters:
@@ -129,6 +131,7 @@ class ReplicationResource(Resource):
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
+ caches_token = self.store.get_cache_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
@@ -140,6 +143,7 @@ class ReplicationResource(Resource):
push_rules_token,
pushers_token,
state_token,
+ caches_token,
))
@request_handler()
@@ -188,6 +192,7 @@ class ReplicationResource(Resource):
yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
+ yield self.caches(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
logger.info("Replicated %d rows", writer.total)
@@ -379,6 +384,20 @@ class ReplicationResource(Resource):
"position", "type", "state_key", "event_id"
))
+ @defer.inlineCallbacks
+ def caches(self, writer, current_token, limit, request_streams):
+ current_position = current_token.caches
+
+ caches = request_streams.get("caches")
+
+ if caches is not None:
+ updated_caches = yield self.store.get_all_updated_caches(
+ caches, current_position, limit
+ )
+ writer.write_header_and_rows("caches", updated_caches, (
+ "position", "cache_func", "keys", "invalidation_ts"
+ ))
+
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
@@ -407,7 +426,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
- "push_rules", "pushers", "state"
+ "push_rules", "pushers", "state", "caches",
))):
__slots__ = []
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 46e43ce1c7..d839d169ab 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,15 +14,43 @@
# limitations under the License.
from synapse.storage._base import SQLBaseStore
+from synapse.storage.engines import PostgresEngine
from twisted.internet import defer
+from ._slaved_id_tracker import SlavedIdTracker
+
+import logging
+
+logger = logging.getLogger(__name__)
+
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs)
+ if isinstance(self.database_engine, PostgresEngine):
+ self._cache_id_gen = SlavedIdTracker(
+ db_conn, "cache_invalidation_stream", "stream_id",
+ )
+ else:
+ self._cache_id_gen = None
def stream_positions(self):
- return {}
+ pos = {}
+ if self._cache_id_gen:
+ pos["caches"] = self._cache_id_gen.get_current_token()
+ return pos
def process_replication(self, result):
+ stream = result.get("caches")
+ if stream:
+ for row in stream["rows"]:
+ (
+ position, cache_func, keys, invalidation_ts,
+ ) = row
+
+ try:
+ getattr(self, cache_func).invalidate(tuple(keys))
+ except AttributeError:
+ logger.warn("Got unexpected cache_func: %r", cache_func)
+ self._cache_id_gen.advance(int(stream["position"]))
return defer.succeed(None)
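A minimal sketch (not part of the patch) of a "caches" stream row as consumed by BaseSlavedStore.process_replication above; the field order matches the header written by ReplicationResource.caches(), the concrete values are made up.

    row = (
        12345,                    # position: stream_id of the invalidation
        "get_aliases_for_room",   # cache_func: name of the cached store method
        ["!room:example.com"],    # keys: the cache key, serialised as a list
        1467000000000,            # invalidation_ts (ms)
    )
    position, cache_func, keys, invalidation_ts = row
    # The slave looks up getattr(store, cache_func), calls
    # .invalidate(tuple(keys)) on it, and then advances its cache id
    # tracker to the stream's "position".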
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 5fbe3a303a..7301d885f2 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -20,4 +20,4 @@ from synapse.storage.directory import DirectoryStore
class DirectoryStore(BaseSlavedStore):
get_aliases_for_room = DirectoryStore.__dict__[
"get_aliases_for_room"
- ].orig
+ ]
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index b0cb31a448..af21661d7c 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -28,6 +28,10 @@ logger = logging.getLogger(__name__)
class WhoisRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
+ def __init__(self, hs):
+ super(WhoisRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
@@ -82,6 +86,10 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
"/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
)
+ def __init__(self, hs):
+ super(PurgeHistoryRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
requester = yield self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index 96b49b01f2..c2a8447860 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -57,7 +57,6 @@ class ClientV1RestServlet(RestServlet):
hs (synapse.server.HomeServer):
"""
self.hs = hs
- self.handlers = hs.get_handlers()
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_v1auth()
self.txns = HttpTransactionStore()
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 8ac09419dc..09d0831594 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -36,6 +36,10 @@ def register_servlets(hs, http_server):
class ClientDirectoryServer(ClientV1RestServlet):
PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
+ def __init__(self, hs):
+ super(ClientDirectoryServer, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
@@ -146,6 +150,7 @@ class ClientDirectoryListServer(ClientV1RestServlet):
def __init__(self, hs):
super(ClientDirectoryListServer, self).__init__(hs)
self.store = hs.get_datastore()
+ self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 498bb9e18a..701b6f549b 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -32,6 +32,10 @@ class EventStreamRestServlet(ClientV1RestServlet):
DEFAULT_LONGPOLL_TIME_MS = 30000
+ def __init__(self, hs):
+ super(EventStreamRestServlet, self).__init__(hs)
+ self.event_stream_handler = hs.get_event_stream_handler()
+
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(
@@ -46,7 +50,6 @@ class EventStreamRestServlet(ClientV1RestServlet):
if "room_id" in request.args:
room_id = request.args["room_id"][0]
- handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if "timeout" in request.args:
@@ -57,7 +60,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args
- chunk = yield handler.get_stream(
+ chunk = yield self.event_stream_handler.get_stream(
requester.user.to_string(),
pagin_config,
timeout=timeout,
@@ -80,12 +83,12 @@ class EventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(EventRestServlet, self).__init__(hs)
self.clock = hs.get_clock()
+ self.event_handler = hs.get_event_handler()
@defer.inlineCallbacks
def on_GET(self, request, event_id):
requester = yield self.auth.get_user_by_req(request)
- handler = self.handlers.event_handler
- event = yield handler.get_event(requester.user, event_id)
+ event = yield self.event_handler.get_event(requester.user, event_id)
time_now = self.clock.time_msec()
if event:
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 36c3520567..113a49e539 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -23,6 +23,10 @@ from .base import ClientV1RestServlet, client_path_patterns
class InitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/initialSync$")
+ def __init__(self, hs):
+ super(InitialSyncRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 92fcae674a..6c0eec8fb3 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -54,12 +54,9 @@ class LoginRestServlet(ClientV1RestServlet):
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled
- self.cas_server_url = hs.config.cas_server_url
- self.cas_required_attributes = hs.config.cas_required_attributes
- self.servername = hs.config.server_name
- self.http_client = hs.get_simple_http_client()
self.auth_handler = self.hs.get_auth_handler()
self.device_handler = self.hs.get_device_handler()
+ self.handlers = hs.get_handlers()
def on_GET(self, request):
flows = []
@@ -110,17 +107,6 @@ class LoginRestServlet(ClientV1RestServlet):
LoginRestServlet.JWT_TYPE):
result = yield self.do_jwt_login(login_submission)
defer.returnValue(result)
- # TODO Delete this after all CAS clients switch to token login instead
- elif self.cas_enabled and (login_submission["type"] ==
- LoginRestServlet.CAS_TYPE):
- uri = "%s/proxyValidate" % (self.cas_server_url,)
- args = {
- "ticket": login_submission["ticket"],
- "service": login_submission["service"]
- }
- body = yield self.http_client.get_raw(uri, args)
- result = yield self.do_cas_login(body)
- defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
defer.returnValue(result)
@@ -191,51 +177,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
- # TODO Delete this after all CAS clients switch to token login instead
- @defer.inlineCallbacks
- def do_cas_login(self, cas_response_body):
- user, attributes = self.parse_cas_response(cas_response_body)
-
- for required_attribute, required_value in self.cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in attributes:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- # Also need to check value
- if required_value is not None:
- actual_value = attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
-
- user_id = UserID.create(user, self.hs.hostname).to_string()
- auth_handler = self.auth_handler
- registered_user_id = yield auth_handler.check_user_exists(user_id)
- if registered_user_id:
- access_token, refresh_token = (
- yield auth_handler.get_login_tuple_for_user_id(
- registered_user_id
- )
- )
- result = {
- "user_id": registered_user_id, # may have changed
- "access_token": access_token,
- "refresh_token": refresh_token,
- "home_server": self.hs.hostname,
- }
-
- else:
- user_id, access_token = (
- yield self.handlers.registration_handler.register(localpart=user)
- )
- result = {
- "user_id": user_id, # may have changed
- "access_token": access_token,
- "home_server": self.hs.hostname,
- }
-
- defer.returnValue((200, result))
-
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
token = login_submission.get("token", None)
@@ -293,33 +234,6 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
- # TODO Delete this after all CAS clients switch to token login instead
- def parse_cas_response(self, cas_response_body):
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not root[0].tag.endswith("authenticationSuccess"):
- raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- attributes = {}
- for attribute in child:
- # ElementTree library expands the namespace in attribute tags
- # to the full URL of the namespace.
- # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
- # We don't care about namespace here and it will always be encased in
- # curly braces, so we remove them.
- if "}" in attribute.tag:
- attributes[attribute.tag.split("}")[1]] = attribute.text
- else:
- attributes[attribute.tag] = attribute.text
- if user is None or attributes is None:
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-
- return (user, attributes)
-
def _register_device(self, user_id, login_submission):
"""Register a device for a user.
@@ -347,6 +261,7 @@ class SAML2RestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(SAML2RestServlet, self).__init__(hs)
self.sp_config = hs.config.saml2_config_path
+ self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -384,18 +299,6 @@ class SAML2RestServlet(ClientV1RestServlet):
defer.returnValue((200, {"status": "not_authenticated"}))
-# TODO Delete this after all CAS clients switch to token login instead
-class CasRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/login/cas", releases=())
-
- def __init__(self, hs):
- super(CasRestServlet, self).__init__(hs)
- self.cas_server_url = hs.config.cas_server_url
-
- def on_GET(self, request):
- return (200, {"serverUrl": self.cas_server_url})
-
-
class CasRedirectServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
@@ -427,6 +330,8 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
+ self.auth_handler = hs.get_auth_handler()
+ self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -479,30 +384,39 @@ class CasTicketServlet(ClientV1RestServlet):
return urlparse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not root[0].tag.endswith("authenticationSuccess"):
- raise LoginError(401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED)
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- attributes = {}
- for attribute in child:
- # ElementTree library expands the namespace in attribute tags
- # to the full URL of the namespace.
- # See (https://docs.python.org/2/library/xml.etree.elementtree.html)
- # We don't care about namespace here and it will always be encased in
- # curly braces, so we remove them.
- if "}" in attribute.tag:
- attributes[attribute.tag.split("}")[1]] = attribute.text
- else:
- attributes[attribute.tag] = attribute.text
- if user is None or attributes is None:
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
-
- return (user, attributes)
+ user = None
+ attributes = None
+ try:
+ root = ET.fromstring(cas_response_body)
+ if not root.tag.endswith("serviceResponse"):
+ raise Exception("root of CAS response is not serviceResponse")
+ success = (root[0].tag.endswith("authenticationSuccess"))
+ for child in root[0]:
+ if child.tag.endswith("user"):
+ user = child.text
+ if child.tag.endswith("attributes"):
+ attributes = {}
+ for attribute in child:
+ # ElementTree library expands the namespace in
+ # attribute tags to the full URL of the namespace.
+ # We don't care about namespace here and it will always
+ # be encased in curly braces, so we remove them.
+ tag = attribute.tag
+ if "}" in tag:
+ tag = tag.split("}")[1]
+ attributes[tag] = attribute.text
+ if user is None:
+ raise Exception("CAS response does not contain user")
+ if attributes is None:
+ raise Exception("CAS response does not contain attributes")
+ except Exception:
+ logger.error("Error parsing CAS response", exc_info=1)
+ raise LoginError(401, "Invalid CAS response",
+ errcode=Codes.UNAUTHORIZED)
+ if not success:
+ raise LoginError(401, "Unsuccessful CAS response",
+ errcode=Codes.UNAUTHORIZED)
+ return user, attributes
def register_servlets(hs, http_server):
@@ -512,5 +426,3 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
- CasRestServlet(hs).register(http_server)
- # TODO PasswordResetRestServlet(hs).register(http_server)
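A minimal sketch (not part of the patch) of the kind of CAS serviceResponse the rewritten parse_cas_response handles; the XML below is a made-up example, and the namespace is stripped from attribute tags exactly as the new code does.

    import xml.etree.ElementTree as ET

    body = (
        "<cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>"
        "<cas:authenticationSuccess>"
        "<cas:user>alice</cas:user>"
        "<cas:attributes>"
        "<cas:email>alice@example.com</cas:email>"
        "</cas:attributes>"
        "</cas:authenticationSuccess>"
        "</cas:serviceResponse>"
    )

    root = ET.fromstring(body)
    user = None
    attributes = None
    for child in root[0]:
        if child.tag.endswith("user"):
            user = child.text
        if child.tag.endswith("attributes"):
            attributes = {}
            for attribute in child:
                tag = attribute.tag
                if "}" in tag:
                    tag = tag.split("}")[1]
                attributes[tag] = attribute.text
    assert user == "alice"
    assert attributes == {"email": "alice@example.com"}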
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 65c4e2ebef..355e82474b 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -24,6 +24,10 @@ from synapse.http.servlet import parse_json_object_from_request
class ProfileDisplaynameRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
+ def __init__(self, hs):
+ super(ProfileDisplaynameRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
@@ -62,6 +66,10 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
class ProfileAvatarURLRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
+ def __init__(self, hs):
+ super(ProfileAvatarURLRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
@@ -99,6 +107,10 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
class ProfileRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
+ def __init__(self, hs):
+ super(ProfileRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, user_id):
user = UserID.from_string(user_id)
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 2383b9df86..71d58c8e8d 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -65,6 +65,7 @@ class RegisterRestServlet(ClientV1RestServlet):
self.sessions = {}
self.enable_registration = hs.config.enable_registration
self.auth_handler = hs.get_auth_handler()
+ self.handlers = hs.get_handlers()
def on_GET(self, request):
if self.hs.config.enable_registration_captcha:
@@ -383,6 +384,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
super(CreateUserRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
self.direct_user_creation_max_duration = hs.config.user_creation_max_duration
+ self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_POST(self, request):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 866a1e9120..89c3895118 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -35,6 +35,10 @@ logger = logging.getLogger(__name__)
class RoomCreateRestServlet(ClientV1RestServlet):
# No PATTERN; we have custom dispatch rules here
+ def __init__(self, hs):
+ super(RoomCreateRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
@@ -82,6 +86,10 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events
class RoomStateEventRestServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(RoomStateEventRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
# /room/$roomid/state/$eventtype
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
@@ -166,6 +174,10 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(RoomSendEventRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
@@ -210,6 +222,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins
class JoinRoomAliasServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(JoinRoomAliasServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@@ -296,6 +311,10 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
class RoomMemberListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
+ def __init__(self, hs):
+ super(RoomMemberListRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
@@ -322,6 +341,10 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
class RoomMessageListRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
+ def __init__(self, hs):
+ super(RoomMessageListRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -351,6 +374,10 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
class RoomStateRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
+ def __init__(self, hs):
+ super(RoomStateRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -368,6 +395,10 @@ class RoomStateRestServlet(ClientV1RestServlet):
class RoomInitialSyncRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
+ def __init__(self, hs):
+ super(RoomInitialSyncRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
@@ -388,6 +419,7 @@ class RoomEventContext(ClientV1RestServlet):
def __init__(self, hs):
super(RoomEventContext, self).__init__(hs)
self.clock = hs.get_clock()
+ self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@@ -424,6 +456,10 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(RoomForgetRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
register_txn_path(self, PATTERNS, http_server)
@@ -462,6 +498,10 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(RoomMembershipRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
@@ -542,6 +582,10 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
class RoomRedactEventRestServlet(ClientV1RestServlet):
+ def __init__(self, hs):
+ super(RoomRedactEventRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
register_txn_path(self, PATTERNS, http_server)
@@ -624,6 +668,10 @@ class SearchRestServlet(ClientV1RestServlet):
"/search$"
)
+ def __init__(self, hs):
+ super(SearchRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
@defer.inlineCallbacks
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 43d8e0bf39..b11acdbea7 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -146,7 +146,7 @@ class SyncRestServlet(RestServlet):
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
- yield self.presence_handler.set_state(user, {"presence": set_presence})
+ yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
context = yield self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence,
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 7209d5a37d..9fe2013657 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -15,6 +15,7 @@
from synapse.http.server import request_handler, respond_with_json_bytes
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.api.errors import SynapseError, Codes
+from synapse.crypto.keyring import KeyLookupError
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -210,9 +211,10 @@ class RemoteKey(Resource):
yield self.keyring.get_server_verify_key_v2_direct(
server_name, key_ids
)
+ except KeyLookupError as e:
+ logger.info("Failed to fetch key: %s", e)
except:
logger.exception("Failed to get key for %r", server_name)
- pass
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
)
diff --git a/synapse/server.py b/synapse/server.py
index 6bb4988309..af3246504b 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -41,6 +41,7 @@ from synapse.handlers.presence import PresenceHandler
from synapse.handlers.room import RoomListHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
+from synapse.handlers.events import EventHandler, EventStreamHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
@@ -94,6 +95,8 @@ class HomeServer(object):
'auth_handler',
'device_handler',
'e2e_keys_handler',
+ 'event_handler',
+ 'event_stream_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
@@ -214,6 +217,12 @@ class HomeServer(object):
def build_application_service_handler(self):
return ApplicationServicesHandler(self)
+ def build_event_handler(self):
+ return EventHandler(self)
+
+ def build_event_stream_handler(self):
+ return EventStreamHandler(self)
+
def build_event_sources(self):
return EventSources(self)
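
Note on the builder convention used above: listing 'event_handler' and 'event_stream_handler' in the dependency list, together with the matching build_* methods, is all that is needed for hs.get_event_handler() and hs.get_event_stream_handler() to become available, because HomeServer generates a caching get_* accessor for every listed name. A simplified, hand-written sketch of that pattern (illustrative only, not the literal synapse code):

class BuilderExample(object):
    # Each name listed here gets a memoising get_<name>() accessor which
    # calls the corresponding build_<name>() the first time it is used.
    DEPENDENCIES = ["event_handler"]

    def __init__(self):
        self._built = {}

    def build_event_handler(self):
        return object()  # stands in for EventHandler(self)

    def get_event_handler(self):
        if "event_handler" not in self._built:
            self._built["event_handler"] = self.build_event_handler()
        return self._built["event_handler"]
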
diff --git a/synapse/server.pyi b/synapse/server.pyi
index c0aa868c4f..9570df5537 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,3 +1,4 @@
+import synapse.api.auth
import synapse.handlers
import synapse.handlers.auth
import synapse.handlers.device
@@ -6,6 +7,9 @@ import synapse.storage
import synapse.state
class HomeServer(object):
+ def get_auth(self) -> synapse.api.auth.Auth:
+ pass
+
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
pass
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 73fb334dd6..7efc5bfeef 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -50,6 +50,7 @@ from .openid import OpenIdStore
from .client_ips import ClientIpStore
from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
+from .engines import PostgresEngine
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -123,6 +124,13 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("deleted_pushers", "stream_id")],
)
+ if isinstance(self.database_engine, PostgresEngine):
+ self._cache_id_gen = StreamIdGenerator(
+ db_conn, "cache_invalidation_stream", "stream_id",
+ )
+ else:
+ self._cache_id_gen = None
+
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 0117fdc639..b0923a9cad 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -19,6 +19,7 @@ from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.dictionary_cache import DictionaryCache
from synapse.util.caches.descriptors import Cache
from synapse.util.caches import intern_dict
+from synapse.storage.engines import PostgresEngine
import synapse.metrics
@@ -305,13 +306,14 @@ class SQLBaseStore(object):
func, *args, **kwargs
)
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(
- inner_func, *args, **kwargs
- )
-
- for after_callback, after_args in after_callbacks:
- after_callback(*after_args)
+ try:
+ with PreserveLoggingContext():
+ result = yield self._db_pool.runWithConnection(
+ inner_func, *args, **kwargs
+ )
+ finally:
+ for after_callback, after_args in after_callbacks:
+ after_callback(*after_args)
defer.returnValue(result)
@defer.inlineCallbacks
@@ -860,6 +862,58 @@ class SQLBaseStore(object):
return cache, min_val
+ def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+ """Invalidates the cache and adds it to the cache stream so slaves
+ will know to invalidate their caches.
+
+ This should only be used to invalidate caches where slaves won't
+ otherwise know from other replication streams that the cache should
+ be invalidated.
+ """
+ txn.call_after(cache_func.invalidate, keys)
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # get_next() returns a context manager which is designed to wrap
+ # the transaction. However, we want to only get an ID when we want
+ # to use it, here, so we need to call __enter__ manually, and have
+ # __exit__ called after the transaction finishes.
+ ctx = self._cache_id_gen.get_next()
+ stream_id = ctx.__enter__()
+ txn.call_after(ctx.__exit__, None, None, None)
+
+ self._simple_insert_txn(
+ txn,
+ table="cache_invalidation_stream",
+ values={
+ "stream_id": stream_id,
+ "cache_func": cache_func.__name__,
+ "keys": list(keys),
+ "invalidation_ts": self.clock.time_msec(),
+ }
+ )
+
+ def get_all_updated_caches(self, last_id, current_id, limit):
+ def get_all_updated_caches_txn(txn):
+ # We purposefully don't bound by the current token, as we want to
+ # send across cache invalidations as quickly as possible. Cache
+ # invalidations are idempotent, so duplicates are fine.
+ sql = (
+ "SELECT stream_id, cache_func, keys, invalidation_ts"
+ " FROM cache_invalidation_stream"
+ " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, limit,))
+ return txn.fetchall()
+ return self.runInteraction(
+ "get_all_updated_caches", get_all_updated_caches_txn
+ )
+
+ def get_cache_stream_token(self):
+ if self._cache_id_gen:
+ return self._cache_id_gen.get_current_token()
+ else:
+ return 0
+
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index ef231a04dc..9caaf81f2c 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -82,32 +82,39 @@ class DirectoryStore(SQLBaseStore):
Returns:
Deferred
"""
- try:
- yield self._simple_insert(
+ def alias_txn(txn):
+ self._simple_insert_txn(
+ txn,
"room_aliases",
{
"room_alias": room_alias.to_string(),
"room_id": room_id,
"creator": creator,
},
- desc="create_room_alias_association",
- )
- except self.database_engine.module.IntegrityError:
- raise SynapseError(
- 409, "Room alias %s already exists" % room_alias.to_string()
)
- for server in servers:
- # TODO(erikj): Fix this to bulk insert
- yield self._simple_insert(
- "room_alias_servers",
- {
+ self._simple_insert_many_txn(
+ txn,
+ table="room_alias_servers",
+ values=[{
"room_alias": room_alias.to_string(),
"server": server,
- },
- desc="create_room_alias_association",
+ } for server in servers],
)
- self.get_aliases_for_room.invalidate((room_id,))
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_aliases_for_room, (room_id,)
+ )
+
+ try:
+ ret = yield self.runInteraction(
+ "create_room_alias_association", alias_txn
+ )
+ except self.database_engine.module.IntegrityError:
+ raise SynapseError(
+ 409, "Room alias %s already exists" % room_alias.to_string()
+ )
+ defer.returnValue(ret)
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 8801669a6b..b94ce7bea1 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 33
+SCHEMA_VERSION = 34
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index d03f7c541e..21d0696640 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -189,18 +189,30 @@ class PresenceStore(SQLBaseStore):
desc="add_presence_list_pending",
)
- @defer.inlineCallbacks
def set_presence_list_accepted(self, observer_localpart, observed_userid):
- result = yield self._simple_update_one(
- table="presence_list",
- keyvalues={"user_id": observer_localpart,
- "observed_user_id": observed_userid},
- updatevalues={"accepted": True},
- desc="set_presence_list_accepted",
+ def update_presence_list_txn(txn):
+ result = self._simple_update_one_txn(
+ txn,
+ table="presence_list",
+ keyvalues={
+ "user_id": observer_localpart,
+ "observed_user_id": observed_userid
+ },
+ updatevalues={"accepted": True},
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_presence_list_accepted, (observer_localpart,)
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_presence_list_observers_accepted, (observed_userid,)
+ )
+
+ return result
+
+ return self.runInteraction(
+ "set_presence_list_accepted", update_presence_list_txn,
)
- self.get_presence_list_accepted.invalidate((observer_localpart,))
- self.get_presence_list_observers_accepted.invalidate((observed_userid,))
- defer.returnValue(result)
def get_presence_list(self, observer_localpart, accepted=None):
if accepted:
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8bd693be72..a422ddf633 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -277,7 +277,6 @@ class RoomMemberStore(SQLBaseStore):
user_id, membership_list=[Membership.JOIN],
)
- @defer.inlineCallbacks
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -292,10 +291,13 @@ class RoomMemberStore(SQLBaseStore):
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
- yield self.runInteraction("forget_membership", f)
- self.was_forgotten_at.invalidate_all()
- self.who_forgot_in_room.invalidate_all()
- self.did_forget.invalidate((user_id, room_id))
+
+ txn.call_after(self.was_forgotten_at.invalidate_all)
+ txn.call_after(self.did_forget.invalidate, (user_id, room_id))
+ self._invalidate_cache_and_stream(
+ txn, self.who_forgot_in_room, (room_id,)
+ )
+ return self.runInteraction("forget_membership", f)
@cachedInlineCallbacks(num_args=2)
def did_forget(self, user_id, room_id):
diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/schema/delta/34/cache_stream.py
new file mode 100644
index 0000000000..3b63a1562d
--- /dev/null
+++ b/synapse/storage/schema/delta/34/cache_stream.py
@@ -0,0 +1,46 @@
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.engines import PostgresEngine
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+# This stream is used to notify replication slaves that some caches have
+# been invalidated that they cannot infer from the other streams.
+CREATE_TABLE = """
+CREATE TABLE cache_invalidation_stream (
+ stream_id BIGINT,
+ cache_func TEXT,
+ keys TEXT[],
+ invalidation_ts BIGINT
+);
+
+CREATE INDEX cache_invalidation_stream_id ON cache_invalidation_stream(stream_id);
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if not isinstance(database_engine, PostgresEngine):
+ return
+
+ for statement in get_statements(CREATE_TABLE.splitlines()):
+ cur.execute(statement)
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
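
Each row written by _invalidate_cache_and_stream carries (stream_id, cache_func, keys, invalidation_ts). A minimal sketch of how a replication consumer might apply such a row on a worker follows; the function is illustrative and is not the actual slave-store implementation in this patch.

def apply_cache_invalidation(store, row):
    """Invalidate the named cache on a worker, given one stream row."""
    stream_id, cache_name, keys, invalidation_ts = row

    cache_func = getattr(store, cache_name, None)
    if cache_func is None:
        # this worker does not have that cache; nothing to do
        return

    cache_func.invalidate(tuple(keys))

As the comment in get_all_updated_caches() notes, invalidations are idempotent, so replaying a duplicate row is harmless.
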
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 5316259d15..7a87045f87 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -317,7 +317,6 @@ def preserve_fn(f):
def g(*args, **kwargs):
with PreserveLoggingContext(current):
return f(*args, **kwargs)
-
return g
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 0b944d3e63..76f301f549 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet import defer
from synapse.util.logcontext import LoggingContext
import synapse.metrics
+from functools import wraps
import logging
@@ -47,6 +49,18 @@ block_db_txn_duration = metrics.register_distribution(
)
+def measure_func(name):
+ def wrapper(func):
+ @wraps(func)
+ @defer.inlineCallbacks
+ def measured_func(self, *args, **kwargs):
+ with Measure(self.clock, name):
+ r = yield func(self, *args, **kwargs)
+ defer.returnValue(r)
+ return measured_func
+ return wrapper
+
+
class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
@@ -64,7 +78,6 @@ class Measure(object):
self.start = self.clock.time_msec()
self.start_context = LoggingContext.current_context()
if not self.start_context:
- logger.warn("Entered Measure without log context: %s", self.name)
self.start_context = LoggingContext("Measure")
self.start_context.__enter__()
self.created_context = True
@@ -85,7 +98,7 @@ class Measure(object):
if context != self.start_context:
logger.warn(
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
- context, self.start_context, self.name
+ self.start_context, context, self.name
)
return
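
A usage sketch for the new measure_func decorator (the handler class and method below are invented): the decorated method must live on an object with a clock attribute, since Measure is constructed from self.clock (hs.get_clock() supplies it elsewhere in synapse), and the decorator is stacked above defer.inlineCallbacks so the whole deferred body runs inside the Measure block.

from twisted.internet import defer

from synapse.util.metrics import measure_func


class ExampleHandler(object):
    """Invented handler; the only requirement is a `clock` attribute."""

    def __init__(self, hs):
        self.clock = hs.get_clock()

    @measure_func("example_double")
    @defer.inlineCallbacks
    def double(self, value):
        # the time spent here is recorded against the "example_double" block
        result = yield defer.succeed(value * 2)
        defer.returnValue(result)
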
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 21077cbe9a..4a8cd19acf 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -14,11 +14,13 @@
# limitations under the License.
import pymacaroons
+from twisted.internet import defer
+import synapse
+import synapse.api.errors
from synapse.handlers.auth import AuthHandler
from tests import unittest
from tests.utils import setup_test_homeserver
-from twisted.internet import defer
class AuthHandlers(object):
@@ -31,11 +33,12 @@ class AuthTestCase(unittest.TestCase):
def setUp(self):
self.hs = yield setup_test_homeserver(handlers=None)
self.hs.handlers = AuthHandlers(self.hs)
+ self.auth_handler = self.hs.handlers.auth_handler
def test_token_is_a_macaroon(self):
self.hs.config.macaroon_secret_key = "this key is a huge secret"
- token = self.hs.handlers.auth_handler.generate_access_token("some_user")
+ token = self.auth_handler.generate_access_token("some_user")
# Check that we can parse the thing with pymacaroons
macaroon = pymacaroons.Macaroon.deserialize(token)
# The most basic of sanity checks
@@ -46,7 +49,7 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.macaroon_secret_key = "this key is a massive secret"
self.hs.clock.now = 5000
- token = self.hs.handlers.auth_handler.generate_access_token("a_user")
+ token = self.auth_handler.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
def verify_gen(caveat):
@@ -67,3 +70,46 @@ class AuthTestCase(unittest.TestCase):
v.satisfy_general(verify_type)
v.satisfy_general(verify_expiry)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
+
+ def test_short_term_login_token_gives_user_id(self):
+ self.hs.clock.now = 1000
+
+ token = self.auth_handler.generate_short_term_login_token(
+ "a_user", 5000
+ )
+
+ self.assertEqual(
+ "a_user",
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
+ )
+ )
+
+ # when we advance the clock, the token should be rejected
+ self.hs.clock.now = 6000
+ with self.assertRaises(synapse.api.errors.AuthError):
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
+ )
+
+ def test_short_term_login_token_cannot_replace_user_id(self):
+ token = self.auth_handler.generate_short_term_login_token(
+ "a_user", 5000
+ )
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+
+ self.assertEqual(
+ "a_user",
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
+ )
+
+ # add another "user_id" caveat, which might allow us to override the
+ # user_id.
+ macaroon.add_first_party_caveat("user_id = b_user")
+
+ with self.assertRaises(synapse.api.errors.AuthError):
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )