diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 7837f1c252..dd8124dbb9 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -347,10 +347,12 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def on_context_state_request(self, context, event_id):
+ def on_context_state_request(self, origin, context, event_id):
if event_id:
pdus = yield self.handler.get_state_for_pdu(
- event_id
+ origin,
+ context,
+ event_id,
)
else:
raise NotImplementedError("Specify an event")
@@ -365,8 +367,8 @@ class ReplicationLayer(object):
@defer.inlineCallbacks
@log_function
- def on_pdu_request(self, event_id):
- pdu = yield self._get_persisted_pdu(event_id)
+ def on_pdu_request(self, origin, event_id):
+ pdu = yield self._get_persisted_pdu(origin, event_id)
if pdu:
defer.returnValue(
@@ -499,13 +501,13 @@ class ReplicationLayer(object):
defer.returnValue(Pdu(**pdu_dict))
@log_function
- def _get_persisted_pdu(self, event_id):
+ def _get_persisted_pdu(self, origin, event_id):
""" Get a PDU from the database with given origin and id.
Returns:
Deferred: Results in a `Pdu`.
"""
- return self.handler.get_persisted_pdu(event_id)
+ return self.handler.get_persisted_pdu(origin, event_id)
def _transaction_from_pdus(self, pdu_list):
"""Returns a new Transaction containing the given PDUs suitable for
diff --git a/synapse/federation/transport.py b/synapse/federation/transport.py
index 92a1f4ce17..d84a44c211 100644
--- a/synapse/federation/transport.py
+++ b/synapse/federation/transport.py
@@ -390,7 +390,7 @@ class TransportLayer(object):
re.compile("^" + PREFIX + "/event/([^/]*)/$"),
self._with_authentication(
lambda origin, content, query, event_id:
- handler.on_pdu_request(event_id)
+ handler.on_pdu_request(origin, event_id)
)
)
@@ -401,6 +401,7 @@ class TransportLayer(object):
self._with_authentication(
lambda origin, content, query, context:
handler.on_context_state_request(
+ origin,
context,
query.get("event_id", [None])[0],
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 00d10609b8..587fa308c8 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -436,9 +436,13 @@ class FederationHandler(BaseHandler):
defer.returnValue(self.pdu_codec.pdu_from_event(event))
@defer.inlineCallbacks
- def get_state_for_pdu(self, event_id):
+ def get_state_for_pdu(self, origin, room_id, event_id):
yield run_on_reactor()
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
+
state_groups = yield self.store.get_state_groups(
[event_id]
)
@@ -488,7 +492,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def get_persisted_pdu(self, event_id):
+ def get_persisted_pdu(self, origin, event_id):
""" Get a PDU from the database with given origin and id.
Returns:
@@ -500,6 +504,13 @@ class FederationHandler(BaseHandler):
)
if event:
+ in_room = yield self.auth.check_host_in_room(
+ event.room_id,
+ origin
+ )
+ if not in_room:
+ raise AuthError(403, "Host not in room.")
+
defer.returnValue(self.pdu_codec.pdu_from_event(event))
else:
defer.returnValue(None)
|