From 3e3f9b684e2853217d86349f289780b397afa88a Mon Sep 17 00:00:00 2001
From: Hubert Chathi <hubert@uhoreg.ca>
Date: Tue, 22 Oct 2019 22:26:30 -0400
Subject: fix unit test

---
 tests/storage/test_devices.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 3cc18f9f1c..039cc79357 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -137,7 +137,9 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         """Check that an specific device ids exist in a list of device update EDUs"""
         self.assertEqual(len(device_updates), len(expected_device_ids))
 
-        received_device_ids = {update["device_id"] for update in device_updates}
+        received_device_ids = {
+            update["device_id"] for edu_type, update in device_updates
+        }
         self.assertEqual(received_device_ids, set(expected_device_ids))
 
     @defer.inlineCallbacks
-- 
cgit 1.5.1


From e7943f660add8b602ea5225060bd0d74e6440017 Mon Sep 17 00:00:00 2001
From: Brendan Abolivier <babolivier@matrix.org>
Date: Wed, 30 Oct 2019 16:15:04 +0000
Subject: Add unit tests

---
 synapse/api/filtering.py    |  2 +-
 tests/api/test_filtering.py | 51 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 52 insertions(+), 1 deletion(-)

diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index a27029c678..bd91b9f018 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -307,7 +307,7 @@ class Filter(object):
             content = event.get("content", {})
             # check if there is a string url field in the content for filtering purposes
             contains_url = isinstance(content.get("url"), text_type)
-            labels = content.get(LabelsField)
+            labels = content.get(LabelsField, [])
 
         return self.check_fields(room_id, sender, ev_type, labels, contains_url)
 
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 6ba623de13..66b3c828db 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -19,6 +19,7 @@ import jsonschema
 
 from twisted.internet import defer
 
+from synapse.api.constants import LabelsField
 from synapse.api.errors import SynapseError
 from synapse.api.filtering import Filter
 from synapse.events import FrozenEvent
@@ -95,6 +96,8 @@ class FilteringTestCase(unittest.TestCase):
                         "types": ["m.room.message"],
                         "not_rooms": ["!726s6s6q:example.com"],
                         "not_senders": ["@spam:example.com"],
+                        "org.matrix.labels": ["#fun"],
+                        "org.matrix.not_labels": ["#work"],
                     },
                     "ephemeral": {
                         "types": ["m.receipt", "m.typing"],
@@ -320,6 +323,54 @@ class FilteringTestCase(unittest.TestCase):
         )
         self.assertFalse(Filter(definition).check(event))
 
+    def test_filter_labels(self):
+        definition = {"org.matrix.labels": ["#fun"]}
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!secretbase:unknown",
+            content={
+                LabelsField: ["#fun"]
+            },
+        )
+
+        self.assertTrue(Filter(definition).check(event))
+
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!secretbase:unknown",
+            content={
+                LabelsField: ["#notfun"]
+            },
+        )
+
+        self.assertFalse(Filter(definition).check(event))
+
+    def test_filter_not_labels(self):
+        definition = {"org.matrix.not_labels": ["#fun"]}
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!secretbase:unknown",
+            content={
+                LabelsField: ["#fun"]
+            },
+        )
+
+        self.assertFalse(Filter(definition).check(event))
+
+        event = MockEvent(
+            sender="@foo:bar",
+            type="m.room.message",
+            room_id="!secretbase:unknown",
+            content={
+                LabelsField: ["#notfun"]
+            },
+        )
+
+        self.assertTrue(Filter(definition).check(event))
+
     @defer.inlineCallbacks
     def test_filter_presence_match(self):
         user_filter_json = {"presence": {"types": ["m.*"]}}
-- 
cgit 1.5.1


From 395683add1d569c0fdfd83d279551a3ba926f4d5 Mon Sep 17 00:00:00 2001
From: Brendan Abolivier <babolivier@matrix.org>
Date: Wed, 30 Oct 2019 16:47:37 +0000
Subject: Add integration tests for sync

---
 tests/rest/client/v1/utils.py           |  15 ++++-
 tests/rest/client/v2_alpha/test_sync.py | 112 +++++++++++++++++++++++++++++++-
 2 files changed, 122 insertions(+), 5 deletions(-)

diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index cdded88b7f..8ea0cb05ea 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -106,13 +106,22 @@ class RestHelper(object):
         self.auth_user_id = temp_id
 
     def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
-        if txn_id is None:
-            txn_id = "m%s" % (str(time.time()))
         if body is None:
             body = "body_text_here"
 
-        path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
         content = {"msgtype": "m.text", "body": body}
+
+        return self.send_event(
+            room_id, "m.room.message", content, txn_id, tok, expect_code
+        )
+
+    def send_event(
+        self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
+    ):
+        if txn_id is None:
+            txn_id = "m%s" % (str(time.time()))
+
+        path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id)
         if tok:
             path = path + "?access_token=%s" % tok
 
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 71895094bd..0263be010f 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -12,9 +12,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import json
 from mock import Mock
 
+from synapse.api.constants import EventTypes, LabelsField
 import synapse.rest.admin
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import sync
@@ -26,7 +27,12 @@ from tests.server import TimedOutException
 class FilterTestCase(unittest.HomeserverTestCase):
 
     user_id = "@apple:test"
-    servlets = [sync.register_servlets]
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+    ]
 
     def make_homeserver(self, reactor, clock):
 
@@ -70,6 +76,108 @@ class FilterTestCase(unittest.HomeserverTestCase):
         )
 
 
+class SyncFilterTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+        sync.register_servlets,
+    ]
+
+    def test_sync_filter_labels(self):
+        sync_filter = json.dumps(
+            {
+                "room": {
+                    "timeline": {
+                        "types": [EventTypes.Message],
+                        "org.matrix.labels": ["#fun"],
+                    }
+                }
+            }
+        )
+
+        events = self._test_sync_filter_labels(sync_filter)
+
+        self.assertEqual(len(events), 2, events)
+        self.assertEqual(events[0]["content"]["body"], "with label", events[0])
+        self.assertEqual(events[1]["content"]["body"], "with label", events[1])
+
+    def test_sync_filter_not_labels(self):
+        sync_filter = json.dumps(
+            {
+                "room": {
+                    "timeline": {
+                        "types": [EventTypes.Message],
+                        "org.matrix.not_labels": ["#fun"],
+                    }
+                }
+            }
+        )
+
+        events = self._test_sync_filter_labels(sync_filter)
+
+        self.assertEqual(len(events), 2, events)
+        self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+        self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+
+    def _test_sync_filter_labels(self, sync_filter):
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+
+        room_id = self.helper.create_room_as(user_id, tok=tok)
+
+        self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with label",
+                LabelsField: ["#fun"],
+            },
+            tok=tok,
+        )
+
+        self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "without label",
+            },
+            tok=tok,
+        )
+
+        self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with wrong label",
+                LabelsField: ["#work"],
+            },
+            tok=tok,
+        )
+
+        self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with label",
+                LabelsField: ["#fun"],
+            },
+            tok=tok,
+        )
+
+        request, channel = self.make_request(
+            "GET", "/sync?filter=%s" % sync_filter, access_token=tok
+        )
+        self.render(request)
+        self.assertEqual(channel.code, 200, channel.result)
+
+        return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
+
+
 class SyncTypingTests(unittest.HomeserverTestCase):
 
     servlets = [
-- 
cgit 1.5.1


From fe51d6cacf6e1a2da5fc3589d0bc4118342b33dd Mon Sep 17 00:00:00 2001
From: Brendan Abolivier <babolivier@matrix.org>
Date: Wed, 30 Oct 2019 17:28:41 +0000
Subject: Add more integration testing

---
 synapse/storage/data_stores/main/stream.py |  2 +-
 tests/rest/client/v2_alpha/test_sync.py    | 45 ++++++++++++++++++++++++++----
 2 files changed, 40 insertions(+), 7 deletions(-)

diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 907d7f20ba..cfa34ba1e7 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -872,7 +872,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         args.append(int(limit))
 
         sql = (
-            "SELECT event_id, topological_ordering, stream_ordering"
+            "SELECT DISTINCT event_id, topological_ordering, stream_ordering"
             " FROM events"
             " LEFT JOIN event_labels USING (event_id)"
             " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 0263be010f..a1aa7d87bd 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -85,6 +85,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
     ]
 
     def test_sync_filter_labels(self):
+        """Test that we can filter by a label."""
         sync_filter = json.dumps(
             {
                 "room": {
@@ -98,11 +99,12 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
 
         events = self._test_sync_filter_labels(sync_filter)
 
-        self.assertEqual(len(events), 2, events)
-        self.assertEqual(events[0]["content"]["body"], "with label", events[0])
-        self.assertEqual(events[1]["content"]["body"], "with label", events[1])
+        self.assertEqual(len(events), 2, [event["content"] for event in events])
+        self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+        self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
 
     def test_sync_filter_not_labels(self):
+        """Test that we can filter by the absence of a label."""
         sync_filter = json.dumps(
             {
                 "room": {
@@ -116,9 +118,29 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
 
         events = self._test_sync_filter_labels(sync_filter)
 
-        self.assertEqual(len(events), 2, events)
+        self.assertEqual(len(events), 3, [event["content"] for event in events])
         self.assertEqual(events[0]["content"]["body"], "without label", events[0])
         self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+        self.assertEqual(events[2]["content"]["body"], "with two wrong labels", events[2])
+
+    def test_sync_filter_labels_not_labels(self):
+        """Test that we can filter by both a label and the absence of another label."""
+        sync_filter = json.dumps(
+            {
+                "room": {
+                    "timeline": {
+                        "types": [EventTypes.Message],
+                        "org.matrix.labels": ["#work"],
+                        "org.matrix.not_labels": ["#notfun"],
+                    }
+                }
+            }
+        )
+
+        events = self._test_sync_filter_labels(sync_filter)
+
+        self.assertEqual(len(events), 1, [event["content"] for event in events])
+        self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
 
     def _test_sync_filter_labels(self, sync_filter):
         user_id = self.register_user("kermit", "test")
@@ -131,7 +153,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             type=EventTypes.Message,
             content={
                 "msgtype": "m.text",
-                "body": "with label",
+                "body": "with right label",
                 LabelsField: ["#fun"],
             },
             tok=tok,
@@ -163,7 +185,18 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             type=EventTypes.Message,
             content={
                 "msgtype": "m.text",
-                "body": "with label",
+                "body": "with two wrong labels",
+                LabelsField: ["#work", "#notfun"],
+            },
+            tok=tok,
+        )
+
+        self.helper.send_event(
+            room_id=room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with right label",
                 LabelsField: ["#fun"],
             },
             tok=tok,
-- 
cgit 1.5.1


From d8c9109aeee58950f0fd4d9865836b82aa7aafb6 Mon Sep 17 00:00:00 2001
From: Brendan Abolivier <babolivier@matrix.org>
Date: Wed, 30 Oct 2019 17:48:22 +0000
Subject: Add integration tests for /messages

---
 tests/rest/client/v1/test_rooms.py | 102 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 101 insertions(+), 1 deletion(-)

diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2f2ca74611..ba2008497e 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse
 from twisted.internet import defer
 
 import synapse.rest.admin
-from synapse.api.constants import Membership
+from synapse.api.constants import EventTypes, LabelsField, Membership
 from synapse.rest.client.v1 import login, profile, room
 
 from tests import unittest
@@ -811,6 +811,106 @@ class RoomMessageListTestCase(RoomBase):
         self.assertTrue("chunk" in channel.json_body)
         self.assertTrue("end" in channel.json_body)
 
+    def test_filter_labels(self):
+        """Test that we can filter by a label."""
+        message_filter = json.dumps({
+            "types": [EventTypes.Message],
+            "org.matrix.labels": ["#fun"],
+        })
+
+        events = self._test_filter_labels(message_filter)
+
+        self.assertEqual(len(events), 2, [event["content"] for event in events])
+        self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
+        self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
+
+    def test_filter_not_labels(self):
+        """Test that we can filter by the absence of a label."""
+        message_filter = json.dumps({
+            "types": [EventTypes.Message],
+            "org.matrix.not_labels": ["#fun"],
+        })
+
+        events = self._test_filter_labels(message_filter)
+
+        self.assertEqual(len(events), 3, [event["content"] for event in events])
+        self.assertEqual(events[0]["content"]["body"], "without label", events[0])
+        self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
+        self.assertEqual(events[2]["content"]["body"], "with two wrong labels", events[2])
+
+    def test_filter_labels_not_labels(self):
+        """Test that we can filter by both a label and the absence of another label."""
+        sync_filter = json.dumps({
+            "types": [EventTypes.Message],
+            "org.matrix.labels": ["#work"],
+            "org.matrix.not_labels": ["#notfun"],
+        })
+
+        events = self._test_filter_labels(sync_filter)
+
+        self.assertEqual(len(events), 1, [event["content"] for event in events])
+        self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
+
+    def _test_filter_labels(self, message_filter):
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with right label",
+                LabelsField: ["#fun"],
+            }
+        )
+
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "without label",
+            }
+        )
+
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with wrong label",
+                LabelsField: ["#work"],
+            }
+        )
+
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with two wrong labels",
+                LabelsField: ["#work", "#notfun"],
+            }
+        )
+
+        self.helper.send_event(
+            room_id=self.room_id,
+            type=EventTypes.Message,
+            content={
+                "msgtype": "m.text",
+                "body": "with right label",
+                LabelsField: ["#fun"],
+            }
+        )
+
+        token = "s0_0_0_0_0_0_0_0_0"
+        request, channel = self.make_request(
+            "GET", "/rooms/%s/messages?access_token=x&from=%s&filter=%s" % (
+                self.room_id, token, message_filter
+            )
+        )
+        self.render(request)
+
+        return channel.json_body["chunk"]
+
 
 class RoomSearchTestCase(unittest.HomeserverTestCase):
     servlets = [
-- 
cgit 1.5.1


From dcc069a2e2540862c233a20037e3e59591a42431 Mon Sep 17 00:00:00 2001
From: Brendan Abolivier <babolivier@matrix.org>
Date: Wed, 30 Oct 2019 18:01:56 +0000
Subject: Lint

---
 synapse/storage/data_stores/main/events.py |  8 +----
 tests/api/test_filtering.py                | 16 +++-------
 tests/rest/client/v1/test_rooms.py         | 49 +++++++++++++++---------------
 tests/rest/client/v2_alpha/test_sync.py    | 12 ++++----
 4 files changed, 35 insertions(+), 50 deletions(-)

diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index f80b5f1a3f..2b900f1ce1 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -2486,13 +2486,7 @@ class EventsStore(
         return self._simple_insert_many_txn(
             txn=txn,
             table="event_labels",
-            values=[
-                {
-                    "event_id": event_id,
-                    "label": label,
-                }
-                for label in labels
-            ],
+            values=[{"event_id": event_id, "label": label} for label in labels],
         )
 
 
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 66b3c828db..e004ab1ee5 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -329,9 +329,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={
-                LabelsField: ["#fun"]
-            },
+            content={LabelsField: ["#fun"]},
         )
 
         self.assertTrue(Filter(definition).check(event))
@@ -340,9 +338,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={
-                LabelsField: ["#notfun"]
-            },
+            content={LabelsField: ["#notfun"]},
         )
 
         self.assertFalse(Filter(definition).check(event))
@@ -353,9 +349,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={
-                LabelsField: ["#fun"]
-            },
+            content={LabelsField: ["#fun"]},
         )
 
         self.assertFalse(Filter(definition).check(event))
@@ -364,9 +358,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={
-                LabelsField: ["#notfun"]
-            },
+            content={LabelsField: ["#notfun"]},
         )
 
         self.assertTrue(Filter(definition).check(event))
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index ba2008497e..188f47bd7d 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -813,10 +813,9 @@ class RoomMessageListTestCase(RoomBase):
 
     def test_filter_labels(self):
         """Test that we can filter by a label."""
-        message_filter = json.dumps({
-            "types": [EventTypes.Message],
-            "org.matrix.labels": ["#fun"],
-        })
+        message_filter = json.dumps(
+            {"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]}
+        )
 
         events = self._test_filter_labels(message_filter)
 
@@ -826,25 +825,28 @@ class RoomMessageListTestCase(RoomBase):
 
     def test_filter_not_labels(self):
         """Test that we can filter by the absence of a label."""
-        message_filter = json.dumps({
-            "types": [EventTypes.Message],
-            "org.matrix.not_labels": ["#fun"],
-        })
+        message_filter = json.dumps(
+            {"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]}
+        )
 
         events = self._test_filter_labels(message_filter)
 
         self.assertEqual(len(events), 3, [event["content"] for event in events])
         self.assertEqual(events[0]["content"]["body"], "without label", events[0])
         self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
-        self.assertEqual(events[2]["content"]["body"], "with two wrong labels", events[2])
+        self.assertEqual(
+            events[2]["content"]["body"], "with two wrong labels", events[2]
+        )
 
     def test_filter_labels_not_labels(self):
         """Test that we can filter by both a label and the absence of another label."""
-        sync_filter = json.dumps({
-            "types": [EventTypes.Message],
-            "org.matrix.labels": ["#work"],
-            "org.matrix.not_labels": ["#notfun"],
-        })
+        sync_filter = json.dumps(
+            {
+                "types": [EventTypes.Message],
+                "org.matrix.labels": ["#work"],
+                "org.matrix.not_labels": ["#notfun"],
+            }
+        )
 
         events = self._test_filter_labels(sync_filter)
 
@@ -859,16 +861,13 @@ class RoomMessageListTestCase(RoomBase):
                 "msgtype": "m.text",
                 "body": "with right label",
                 LabelsField: ["#fun"],
-            }
+            },
         )
 
         self.helper.send_event(
             room_id=self.room_id,
             type=EventTypes.Message,
-            content={
-                "msgtype": "m.text",
-                "body": "without label",
-            }
+            content={"msgtype": "m.text", "body": "without label"},
         )
 
         self.helper.send_event(
@@ -878,7 +877,7 @@ class RoomMessageListTestCase(RoomBase):
                 "msgtype": "m.text",
                 "body": "with wrong label",
                 LabelsField: ["#work"],
-            }
+            },
         )
 
         self.helper.send_event(
@@ -888,7 +887,7 @@ class RoomMessageListTestCase(RoomBase):
                 "msgtype": "m.text",
                 "body": "with two wrong labels",
                 LabelsField: ["#work", "#notfun"],
-            }
+            },
         )
 
         self.helper.send_event(
@@ -898,14 +897,14 @@ class RoomMessageListTestCase(RoomBase):
                 "msgtype": "m.text",
                 "body": "with right label",
                 LabelsField: ["#fun"],
-            }
+            },
         )
 
         token = "s0_0_0_0_0_0_0_0_0"
         request, channel = self.make_request(
-            "GET", "/rooms/%s/messages?access_token=x&from=%s&filter=%s" % (
-                self.room_id, token, message_filter
-            )
+            "GET",
+            "/rooms/%s/messages?access_token=x&from=%s&filter=%s"
+            % (self.room_id, token, message_filter),
         )
         self.render(request)
 
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index a1aa7d87bd..c5c199d412 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+
 from mock import Mock
 
-from synapse.api.constants import EventTypes, LabelsField
 import synapse.rest.admin
+from synapse.api.constants import EventTypes, LabelsField
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import sync
 
@@ -121,7 +122,9 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(events), 3, [event["content"] for event in events])
         self.assertEqual(events[0]["content"]["body"], "without label", events[0])
         self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
-        self.assertEqual(events[2]["content"]["body"], "with two wrong labels", events[2])
+        self.assertEqual(
+            events[2]["content"]["body"], "with two wrong labels", events[2]
+        )
 
     def test_sync_filter_labels_not_labels(self):
         """Test that we can filter by both a label and the absence of another label."""
@@ -162,10 +165,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
         self.helper.send_event(
             room_id=room_id,
             type=EventTypes.Message,
-            content={
-                "msgtype": "m.text",
-                "body": "without label",
-            },
+            content={"msgtype": "m.text", "body": "without label"},
             tok=tok,
         )
 
-- 
cgit 1.5.1


From bb6cec27a5ac6d5d6d5f67df21610a63745ac0a9 Mon Sep 17 00:00:00 2001
From: Hubert Chathi <hubert@uhoreg.ca>
Date: Wed, 30 Oct 2019 14:57:34 -0400
Subject: rename get_devices_by_remote to get_device_updates_by_remote

---
 synapse/federation/sender/per_destination_queue.py |  4 ++--
 synapse/storage/data_stores/main/devices.py        |  8 ++++----
 tests/handlers/test_typing.py                      |  4 ++--
 tests/storage/test_devices.py                      | 12 ++++++------
 4 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index d5d4a60c88..6e3012cd41 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -359,7 +359,7 @@ class PerDestinationQueue(object):
         last_device_list = self._last_device_list_stream_id
 
         # Retrieve list of new device updates to send to the destination
-        now_stream_id, results = yield self._store.get_devices_by_remote(
+        now_stream_id, results = yield self._store.get_device_updates_by_remote(
             self._destination, last_device_list, limit=limit
         )
         edus = [
@@ -372,7 +372,7 @@ class PerDestinationQueue(object):
             for (edu_type, content) in results
         ]
 
-        assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
+        assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
 
         return (edus, now_stream_id)
 
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 0b12bc58c4..717eab4159 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -91,7 +91,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
     @trace
     @defer.inlineCallbacks
-    def get_devices_by_remote(self, destination, from_stream_id, limit):
+    def get_device_updates_by_remote(self, destination, from_stream_id, limit):
         """Get a stream of device updates to send to the given remote server.
 
         Args:
@@ -123,8 +123,8 @@ class DeviceWorkerStore(SQLBaseStore):
         # stream_id; the rationale being that such a large device list update
         # is likely an error.
         updates = yield self.runInteraction(
-            "get_devices_by_remote",
-            self._get_devices_by_remote_txn,
+            "get_device_updates_by_remote",
+            self._get_device_updates_by_remote_txn,
             destination,
             from_stream_id,
             now_stream_id,
@@ -241,7 +241,7 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return now_stream_id, results
 
-    def _get_devices_by_remote_txn(
+    def _get_device_updates_by_remote_txn(
         self, txn, destination, from_stream_id, now_stream_id, limit
     ):
         """Return device update information for a given remote destination
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f360c8e965..5ec568f4e6 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -73,7 +73,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
                         "get_received_txn_response",
                         "set_received_txn_response",
                         "get_destination_retry_timings",
-                        "get_devices_by_remote",
+                        "get_device_updates_by_remote",
                         # Bits that user_directory needs
                         "get_user_directory_stream_pos",
                         "get_current_state_deltas",
@@ -109,7 +109,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             retry_timings_res
         )
 
-        self.datastore.get_devices_by_remote.return_value = (0, [])
+        self.datastore.get_device_updates_by_remote.return_value = (0, [])
 
         def get_received_txn_response(*args):
             return defer.succeed(None)
diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py
index 039cc79357..6f8d990959 100644
--- a/tests/storage/test_devices.py
+++ b/tests/storage/test_devices.py
@@ -72,7 +72,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
     @defer.inlineCallbacks
-    def test_get_devices_by_remote(self):
+    def test_get_device_updates_by_remote(self):
         device_ids = ["device_id1", "device_id2"]
 
         # Add two device updates with a single stream_id
@@ -81,7 +81,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         )
 
         # Get all device updates ever meant for this remote
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "somehost", -1, limit=100
         )
 
@@ -89,7 +89,7 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         self._check_devices_in_updates(device_ids, device_updates)
 
     @defer.inlineCallbacks
-    def test_get_devices_by_remote_limited(self):
+    def test_get_device_updates_by_remote_limited(self):
         # Test breaking the update limit in 1, 101, and 1 device_id segments
 
         # first add one device
@@ -115,20 +115,20 @@ class DeviceStoreTestCase(tests.unittest.TestCase):
         #
 
         # first we should get a single update
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "someotherhost", -1, limit=100
         )
         self._check_devices_in_updates(device_ids1, device_updates)
 
         # Then we should get an empty list back as the 101 devices broke the limit
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "someotherhost", now_stream_id, limit=100
         )
         self.assertEqual(len(device_updates), 0)
 
         # The 101 devices should've been cleared, so we should now just get one device
         # update
-        now_stream_id, device_updates = yield self.store.get_devices_by_remote(
+        now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
             "someotherhost", now_stream_id, limit=100
         )
         self._check_devices_in_updates(device_ids3, device_updates)
-- 
cgit 1.5.1


From c6dbca2422bf77ccbf0b52d9245d28c258dac4f3 Mon Sep 17 00:00:00 2001
From: Brendan Abolivier <babolivier@matrix.org>
Date: Fri, 1 Nov 2019 10:30:51 +0000
Subject: Incorporate review

---
 changelog.d/6301.feature                   |  2 +-
 synapse/api/constants.py                   |  5 ++++-
 synapse/api/filtering.py                   |  6 ++++--
 synapse/storage/data_stores/main/events.py | 12 ++++++++++--
 tests/api/test_filtering.py                | 10 +++++-----
 tests/rest/client/v1/test_rooms.py         | 10 +++++-----
 tests/rest/client/v2_alpha/test_sync.py    | 10 +++++-----
 7 files changed, 34 insertions(+), 21 deletions(-)

diff --git a/changelog.d/6301.feature b/changelog.d/6301.feature
index b7ff3fad3b..78a187a1dc 100644
--- a/changelog.d/6301.feature
+++ b/changelog.d/6301.feature
@@ -1 +1 @@
-Implement label-based filtering.
+Implement label-based filtering on `/sync` and `/messages` ([MSC2326](https://github.com/matrix-org/matrix-doc/pull/2326)).
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 999ec02fd9..cf4ce5f5a2 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -140,4 +140,7 @@ class LimitBlockingTypes(object):
     HS_DISABLED = "hs_disabled"
 
 
-LabelsField = "org.matrix.labels"
+class EventContentFields(object):
+    """Fields found in events' content, regardless of type."""
+    # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
+    Labels = "org.matrix.labels"
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index bd91b9f018..30a7ee0a7a 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -20,7 +20,7 @@ from jsonschema import FormatChecker
 
 from twisted.internet import defer
 
-from synapse.api.constants import LabelsField
+from synapse.api.constants import EventContentFields
 from synapse.api.errors import SynapseError
 from synapse.storage.presence import UserPresenceState
 from synapse.types import RoomID, UserID
@@ -67,6 +67,8 @@ ROOM_EVENT_FILTER_SCHEMA = {
         "contains_url": {"type": "boolean"},
         "lazy_load_members": {"type": "boolean"},
         "include_redundant_members": {"type": "boolean"},
+        # Include or exclude events with the provided labels.
+        # cf https://github.com/matrix-org/matrix-doc/pull/2326
         "org.matrix.labels": {"type": "array", "items": {"type": "string"}},
         "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
     },
@@ -307,7 +309,7 @@ class Filter(object):
             content = event.get("content", {})
             # check if there is a string url field in the content for filtering purposes
             contains_url = isinstance(content.get("url"), text_type)
-            labels = content.get(LabelsField, [])
+            labels = content.get(EventContentFields.Labels, [])
 
         return self.check_fields(room_id, sender, ev_type, labels, contains_url)
 
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 2b900f1ce1..42ffa9066a 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -29,7 +29,7 @@ from prometheus_client import Counter, Histogram
 from twisted.internet import defer
 
 import synapse.metrics
-from synapse.api.constants import EventTypes, LabelsField
+from synapse.api.constants import EventTypes, EventContentFields
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
@@ -1491,7 +1491,7 @@ class EventsStore(
             self._handle_event_relations(txn, event)
 
             # Store the labels for this event.
-            labels = event.content.get(LabelsField)
+            labels = event.content.get(EventContentFields.Labels)
             if labels:
                 self.insert_labels_for_event_txn(txn, event.event_id, labels)
 
@@ -2483,6 +2483,14 @@ class EventsStore(
         )
 
     def insert_labels_for_event_txn(self, txn, event_id, labels):
+        """Store the mapping between an event's ID and its labels, with one row per
+        (event_id, label) tuple.
+
+        Args:
+            txn (LoggingTransaction): The transaction to execute.
+            event_id (str): The event's ID.
+            labels (list[str]): A list of text labels.
+        """
         return self._simple_insert_many_txn(
             txn=txn,
             table="event_labels",
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index e004ab1ee5..8ec48c4154 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -19,7 +19,7 @@ import jsonschema
 
 from twisted.internet import defer
 
-from synapse.api.constants import LabelsField
+from synapse.api.constants import EventContentFields
 from synapse.api.errors import SynapseError
 from synapse.api.filtering import Filter
 from synapse.events import FrozenEvent
@@ -329,7 +329,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={LabelsField: ["#fun"]},
+            content={EventContentFields.Labels: ["#fun"]},
         )
 
         self.assertTrue(Filter(definition).check(event))
@@ -338,7 +338,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={LabelsField: ["#notfun"]},
+            content={EventContentFields.Labels: ["#notfun"]},
         )
 
         self.assertFalse(Filter(definition).check(event))
@@ -349,7 +349,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={LabelsField: ["#fun"]},
+            content={EventContentFields.Labels: ["#fun"]},
         )
 
         self.assertFalse(Filter(definition).check(event))
@@ -358,7 +358,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={LabelsField: ["#notfun"]},
+            content={EventContentFields.Labels: ["#notfun"]},
         )
 
         self.assertTrue(Filter(definition).check(event))
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 188f47bd7d..0dc0faa0e5 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse
 from twisted.internet import defer
 
 import synapse.rest.admin
-from synapse.api.constants import EventTypes, LabelsField, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.rest.client.v1 import login, profile, room
 
 from tests import unittest
@@ -860,7 +860,7 @@ class RoomMessageListTestCase(RoomBase):
             content={
                 "msgtype": "m.text",
                 "body": "with right label",
-                LabelsField: ["#fun"],
+                EventContentFields.Labels: ["#fun"],
             },
         )
 
@@ -876,7 +876,7 @@ class RoomMessageListTestCase(RoomBase):
             content={
                 "msgtype": "m.text",
                 "body": "with wrong label",
-                LabelsField: ["#work"],
+                EventContentFields.Labels: ["#work"],
             },
         )
 
@@ -886,7 +886,7 @@ class RoomMessageListTestCase(RoomBase):
             content={
                 "msgtype": "m.text",
                 "body": "with two wrong labels",
-                LabelsField: ["#work", "#notfun"],
+                EventContentFields.Labels: ["#work", "#notfun"],
             },
         )
 
@@ -896,7 +896,7 @@ class RoomMessageListTestCase(RoomBase):
             content={
                 "msgtype": "m.text",
                 "body": "with right label",
-                LabelsField: ["#fun"],
+                EventContentFields.Labels: ["#fun"],
             },
         )
 
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index c5c199d412..c3c6f75ced 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -17,7 +17,7 @@ import json
 from mock import Mock
 
 import synapse.rest.admin
-from synapse.api.constants import EventTypes, LabelsField
+from synapse.api.constants import EventContentFields, EventTypes
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import sync
 
@@ -157,7 +157,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             content={
                 "msgtype": "m.text",
                 "body": "with right label",
-                LabelsField: ["#fun"],
+                EventContentFields.Labels: ["#fun"],
             },
             tok=tok,
         )
@@ -175,7 +175,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             content={
                 "msgtype": "m.text",
                 "body": "with wrong label",
-                LabelsField: ["#work"],
+                EventContentFields.Labels: ["#work"],
             },
             tok=tok,
         )
@@ -186,7 +186,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             content={
                 "msgtype": "m.text",
                 "body": "with two wrong labels",
-                LabelsField: ["#work", "#notfun"],
+                EventContentFields.Labels: ["#work", "#notfun"],
             },
             tok=tok,
         )
@@ -197,7 +197,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             content={
                 "msgtype": "m.text",
                 "body": "with right label",
-                LabelsField: ["#fun"],
+                EventContentFields.Labels: ["#fun"],
             },
             tok=tok,
         )
-- 
cgit 1.5.1


From 1cb84c6486a5131dd284f341bb657434becda255 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Fri, 1 Nov 2019 14:07:44 +0000
Subject: Support for routing outbound HTTP requests via a proxy (#6239)

The `http_proxy` and `HTTPS_PROXY` env vars can be set to a `host[:port]` value which should point to a proxy.

The address of the proxy should be excluded from IP blacklists such as the `url_preview_ip_range_blacklist`.

The proxy will then be used for
 * push
 * url previews
 * phone-home stats
 * recaptcha validation
 * CAS auth validation

It will *not* be used for:
 * Application Services
 * Identity servers
 * Outbound federation
 * In worker configurations, connections from workers to masters

Fixes #4198.
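
As a rough illustration (not part of this patch), the new http_proxy/https_proxy
arguments added to SimpleHttpClient below could be wired up from the environment
along the following lines. The helper name and the exact wiring are hypothetical;
only the host[:port] form of the variables and the bytes-typed arguments come from
the change itself.

    import os

    from synapse.http.client import SimpleHttpClient

    def build_proxied_http_client(hs):
        # The proxy arguments are documented as bytes of the form host[:port],
        # so encode the (string) environment values when they are set.
        http_proxy = os.environ.get("http_proxy")
        https_proxy = os.environ.get("HTTPS_PROXY")
        return SimpleHttpClient(
            hs,
            http_proxy=http_proxy.encode("ascii") if http_proxy else None,
            https_proxy=https_proxy.encode("ascii") if https_proxy else None,
        )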
---
 changelog.d/6238.feature                           |   1 +
 synapse/app/homeserver.py                          |   2 +-
 synapse/handlers/ui_auth/checkers.py               |   2 +-
 synapse/http/client.py                             |  17 +-
 synapse/http/connectproxyclient.py                 | 195 ++++++++++++
 synapse/http/proxyagent.py                         | 195 ++++++++++++
 synapse/push/httppusher.py                         |   2 +-
 synapse/rest/client/v1/login.py                    |   2 +-
 synapse/rest/media/v1/preview_url_resource.py      |   2 +
 synapse/server.py                                  |   9 +
 synapse/server.pyi                                 |   9 +
 tests/http/__init__.py                             |  17 ++
 .../federation/test_matrix_federation_agent.py     |  11 +-
 tests/http/test_proxyagent.py                      | 334 +++++++++++++++++++++
 tests/push/test_http.py                            |   2 +-
 tests/server.py                                    |  24 +-
 16 files changed, 812 insertions(+), 12 deletions(-)
 create mode 100644 changelog.d/6238.feature
 create mode 100644 synapse/http/connectproxyclient.py
 create mode 100644 synapse/http/proxyagent.py
 create mode 100644 tests/http/test_proxyagent.py

diff --git a/changelog.d/6238.feature b/changelog.d/6238.feature
new file mode 100644
index 0000000000..d225ac33b6
--- /dev/null
+++ b/changelog.d/6238.feature
@@ -0,0 +1 @@
+Add support for outbound http proxying via http_proxy/HTTPS_PROXY env vars.
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8997c1f9e7..8d28076d92 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -565,7 +565,7 @@ def run(hs):
             "Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)
         )
         try:
-            yield hs.get_simple_http_client().put_json(
+            yield hs.get_proxied_http_client().put_json(
                 hs.config.report_stats_endpoint, stats
             )
         except Exception as e:
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 29aa1e5aaf..8363d887a9 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -81,7 +81,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
     def __init__(self, hs):
         super().__init__(hs)
         self._enabled = bool(hs.config.recaptcha_private_key)
-        self._http_client = hs.get_simple_http_client()
+        self._http_client = hs.get_proxied_http_client()
         self._url = hs.config.recaptcha_siteverify_api
         self._secret = hs.config.recaptcha_private_key
 
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 2df5b383b5..d4c285445e 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -45,6 +45,7 @@ from synapse.http import (
     cancelled_to_request_timed_out_error,
     redact_uri,
 )
+from synapse.http.proxyagent import ProxyAgent
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag, start_active_span, tags
 from synapse.util.async_helpers import timeout_deferred
@@ -183,7 +184,15 @@ class SimpleHttpClient(object):
     using HTTP in Matrix
     """
 
-    def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
+    def __init__(
+        self,
+        hs,
+        treq_args={},
+        ip_whitelist=None,
+        ip_blacklist=None,
+        http_proxy=None,
+        https_proxy=None,
+    ):
         """
         Args:
             hs (synapse.server.HomeServer)
@@ -192,6 +201,8 @@ class SimpleHttpClient(object):
                 we may not request.
             ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
                request if it were otherwise caught in a blacklist.
+            http_proxy (bytes): proxy server to use for http connections. host[:port]
+            https_proxy (bytes): proxy server to use for https connections. host[:port]
         """
         self.hs = hs
 
@@ -236,11 +247,13 @@ class SimpleHttpClient(object):
         # The default context factory in Twisted 14.0.0 (which we require) is
         # BrowserLikePolicyForHTTPS which will do regular cert validation
         # 'like a browser'
-        self.agent = Agent(
+        self.agent = ProxyAgent(
             self.reactor,
             connectTimeout=15,
             contextFactory=self.hs.get_http_client_context_factory(),
             pool=pool,
+            http_proxy=http_proxy,
+            https_proxy=https_proxy,
         )
 
         if self._ip_blacklist:
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
new file mode 100644
index 0000000000..be7b2ceb8e
--- /dev/null
+++ b/synapse/http/connectproxyclient.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from zope.interface import implementer
+
+from twisted.internet import defer, protocol
+from twisted.internet.error import ConnectError
+from twisted.internet.interfaces import IStreamClientEndpoint
+from twisted.internet.protocol import connectionDone
+from twisted.web import http
+
+logger = logging.getLogger(__name__)
+
+
+class ProxyConnectError(ConnectError):
+    pass
+
+
+@implementer(IStreamClientEndpoint)
+class HTTPConnectProxyEndpoint(object):
+    """An Endpoint implementation which will send a CONNECT request to an http proxy
+
+    Wraps an existing HostnameEndpoint for the proxy.
+
+    When we get the connect() request from the connection pool (via the TLS wrapper),
+    we'll first connect to the proxy endpoint with a ProtocolFactory which will make the
+    CONNECT request. Once that completes, we invoke the protocolFactory which was passed
+    in.
+
+    Args:
+        reactor: the Twisted reactor to use for the connection
+        proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
+            proxy
+        host (bytes): hostname that we want to CONNECT to
+        port (int): port that we want to connect to
+    """
+
+    def __init__(self, reactor, proxy_endpoint, host, port):
+        self._reactor = reactor
+        self._proxy_endpoint = proxy_endpoint
+        self._host = host
+        self._port = port
+
+    def __repr__(self):
+        return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
+
+    def connect(self, protocolFactory):
+        f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+        d = self._proxy_endpoint.connect(f)
+        # once the tcp socket connects successfully, we need to wait for the
+        # CONNECT to complete.
+        d.addCallback(lambda conn: f.on_connection)
+        return d
+
+
+class HTTPProxiedClientFactory(protocol.ClientFactory):
+    """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect.
+
+    Once the CONNECT completes, invokes the original ClientFactory to build the
+    HTTP Protocol object and run the rest of the connection.
+
+    Args:
+        dst_host (bytes): hostname that we want to CONNECT to
+        dst_port (int): port that we want to connect to
+        wrapped_factory (protocol.ClientFactory): The original Factory
+    """
+
+    def __init__(self, dst_host, dst_port, wrapped_factory):
+        self.dst_host = dst_host
+        self.dst_port = dst_port
+        self.wrapped_factory = wrapped_factory
+        self.on_connection = defer.Deferred()
+
+    def startedConnecting(self, connector):
+        return self.wrapped_factory.startedConnecting(connector)
+
+    def buildProtocol(self, addr):
+        wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
+
+        return HTTPConnectProtocol(
+            self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+        )
+
+    def clientConnectionFailed(self, connector, reason):
+        logger.debug("Connection to proxy failed: %s", reason)
+        if not self.on_connection.called:
+            self.on_connection.errback(reason)
+        return self.wrapped_factory.clientConnectionFailed(connector, reason)
+
+    def clientConnectionLost(self, connector, reason):
+        logger.debug("Connection to proxy lost: %s", reason)
+        if not self.on_connection.called:
+            self.on_connection.errback(reason)
+        return self.wrapped_factory.clientConnectionLost(connector, reason)
+
+
+class HTTPConnectProtocol(protocol.Protocol):
+    """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
+
+    Args:
+        host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+            to put in the CONNECT request
+
+        port (int): The original HTTP(s) port to put in the CONNECT request
+
+        wrapped_protocol (interfaces.IProtocol): the original protocol (probably
+            HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+
+        connected_deferred (Deferred): a Deferred which will be callbacked with
+            wrapped_protocol when the CONNECT completes
+    """
+
+    def __init__(self, host, port, wrapped_protocol, connected_deferred):
+        self.host = host
+        self.port = port
+        self.wrapped_protocol = wrapped_protocol
+        self.connected_deferred = connected_deferred
+        self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+        self.http_setup_client.on_connected.addCallback(self.proxyConnected)
+
+    def connectionMade(self):
+        self.http_setup_client.makeConnection(self.transport)
+
+    def connectionLost(self, reason=connectionDone):
+        if self.wrapped_protocol.connected:
+            self.wrapped_protocol.connectionLost(reason)
+
+        self.http_setup_client.connectionLost(reason)
+
+        if not self.connected_deferred.called:
+            self.connected_deferred.errback(reason)
+
+    def proxyConnected(self, _):
+        self.wrapped_protocol.makeConnection(self.transport)
+
+        self.connected_deferred.callback(self.wrapped_protocol)
+
+        # Get any pending data from the http buf and forward it to the original protocol
+        buf = self.http_setup_client.clearLineBuffer()
+        if buf:
+            self.wrapped_protocol.dataReceived(buf)
+
+    def dataReceived(self, data):
+        # if we've set up the HTTP protocol, we can send the data there
+        if self.wrapped_protocol.connected:
+            return self.wrapped_protocol.dataReceived(data)
+
+        # otherwise, we must still be setting up the connection: send the data to the
+        # setup client
+        return self.http_setup_client.dataReceived(data)
+
+
+class HTTPConnectSetupClient(http.HTTPClient):
+    """HTTPClient protocol to send a CONNECT message for proxies and read the response.
+
+    Args:
+        host (bytes): The hostname to send in the CONNECT message
+        port (int): The port to send in the CONNECT message
+    """
+
+    def __init__(self, host, port):
+        self.host = host
+        self.port = port
+        self.on_connected = defer.Deferred()
+
+    def connectionMade(self):
+        logger.debug("Connected to proxy, sending CONNECT")
+        self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+        self.endHeaders()
+
+    def handleStatus(self, version, status, message):
+        logger.debug("Got Status: %s %s %s", status, message, version)
+        if status != b"200":
+            raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
+
+    def handleEndHeaders(self):
+        logger.debug("End Headers")
+        self.on_connected.callback(None)
+
+    def handleResponse(self, body):
+        pass
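
Taken together, the classes above implement an HTTP CONNECT tunnel: the endpoint dials
the proxy, HTTPConnectSetupClient performs the CONNECT exchange, and only once that has
succeeded is the wrapped protocol attached to the transport. As a rough editor's sketch
of how the endpoint might be driven (not part of the patch; the hostnames and ports are
made up):

    from twisted.internet import reactor
    from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
    from twisted.web.client import BrowserLikePolicyForHTTPS

    from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint

    # endpoint for the proxy itself
    proxy_endpoint = HostnameEndpoint(reactor, b"proxy.example.com", 8888)

    # endpoint which will issue "CONNECT example.com:443" through the proxy
    tunnel = HTTPConnectProxyEndpoint(reactor, proxy_endpoint, b"example.com", 443)

    # wrap the tunnel in TLS so the handshake happens end-to-end with example.com,
    # not with the proxy
    tls_creator = BrowserLikePolicyForHTTPS().creatorForNetloc(b"example.com", 443)
    tls_tunnel = wrapClientTLS(tls_creator, tunnel)

    # tls_tunnel.connect(<an HTTP client factory>) would now run the CONNECT
    # handshake and then start TLS over the tunnelled connection.
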
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
new file mode 100644
index 0000000000..332da02a8d
--- /dev/null
+++ b/synapse/http/proxyagent.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import re
+
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.python.failure import Failure
+from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
+from twisted.web.error import SchemeNotSupported
+from twisted.web.iweb import IAgent
+
+from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+
+logger = logging.getLogger(__name__)
+
+_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
+
+
+@implementer(IAgent)
+class ProxyAgent(_AgentBase):
+    """An Agent implementation which will use an HTTP proxy if one was requested
+
+    Args:
+        reactor: twisted reactor to use for making outgoing
+            connections.
+
+        contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
+            verification parameters of OpenSSL.  The default is to use a
+            `BrowserLikePolicyForHTTPS`, so unless you have special
+            requirements you can leave this as-is.
+
+        connectTimeout (float): The amount of time that this Agent will wait
+            for the peer to accept a connection.
+
+        bindAddress (bytes): The local address for client sockets to bind to.
+
+        pool (HTTPConnectionPool|None): connection pool to be used. If None, a
+            non-persistent pool instance will be created.
+    """
+
+    def __init__(
+        self,
+        reactor,
+        contextFactory=BrowserLikePolicyForHTTPS(),
+        connectTimeout=None,
+        bindAddress=None,
+        pool=None,
+        http_proxy=None,
+        https_proxy=None,
+    ):
+        _AgentBase.__init__(self, reactor, pool)
+
+        self._endpoint_kwargs = {}
+        if connectTimeout is not None:
+            self._endpoint_kwargs["timeout"] = connectTimeout
+        if bindAddress is not None:
+            self._endpoint_kwargs["bindAddress"] = bindAddress
+
+        self.http_proxy_endpoint = _http_proxy_endpoint(
+            http_proxy, reactor, **self._endpoint_kwargs
+        )
+
+        self.https_proxy_endpoint = _http_proxy_endpoint(
+            https_proxy, reactor, **self._endpoint_kwargs
+        )
+
+        self._policy_for_https = contextFactory
+        self._reactor = reactor
+
+    def request(self, method, uri, headers=None, bodyProducer=None):
+        """
+        Issue a request to the server indicated by the given uri.
+
+        Supports `http` and `https` schemes.
+
+        An existing connection from the connection pool may be used or a new one may be
+        created.
+
+        See also: twisted.web.iweb.IAgent.request
+
+        Args:
+            method (bytes): The request method to use, such as `GET`, `POST`, etc
+
+            uri (bytes): The location of the resource to request.
+
+            headers (Headers|None): Extra headers to send with the request
+
+            bodyProducer (IBodyProducer|None): An object which can generate bytes to
+                make up the body of this request (for example, the properly encoded
+                contents of a file for a file upload). Or, None if the request is to
+                have no body.
+
+        Returns:
+            Deferred[IResponse]: completes when the header of the response has
+                 been received (regardless of the response status code).
+        """
+        uri = uri.strip()
+        if not _VALID_URI.match(uri):
+            raise ValueError("Invalid URI {!r}".format(uri))
+
+        parsed_uri = URI.fromBytes(uri)
+        pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
+        request_path = parsed_uri.originForm
+
+        if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+            # Cache *all* connections under the same key, since we are only
+            # connecting to a single destination, the proxy:
+            pool_key = ("http-proxy", self.http_proxy_endpoint)
+            endpoint = self.http_proxy_endpoint
+            request_path = uri
+        elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+            endpoint = HTTPConnectProxyEndpoint(
+                self._reactor,
+                self.https_proxy_endpoint,
+                parsed_uri.host,
+                parsed_uri.port,
+            )
+        else:
+            # not using a proxy
+            endpoint = HostnameEndpoint(
+                self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
+            )
+
+        logger.debug("Requesting %s via %s", uri, endpoint)
+
+        if parsed_uri.scheme == b"https":
+            tls_connection_creator = self._policy_for_https.creatorForNetloc(
+                parsed_uri.host, parsed_uri.port
+            )
+            endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+        elif parsed_uri.scheme == b"http":
+            pass
+        else:
+            return defer.fail(
+                Failure(
+                    SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,))
+                )
+            )
+
+        return self._requestWithEndpoint(
+            pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path
+        )
+
+
+def _http_proxy_endpoint(proxy, reactor, **kwargs):
+    """Parses an http proxy setting and returns an endpoint for the proxy
+
+    Args:
+        proxy (bytes|None):  the proxy setting
+        reactor: reactor to be used to connect to the proxy
+        kwargs: other args to be passed to HostnameEndpoint
+
+    Returns:
+        interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
+            or None
+    """
+    if proxy is None:
+        return None
+
+    # currently we only support hostname:port. Some apps also support
+    # protocol://<host>[:port], which provides a way of requiring a TLS connection to the
+    # proxy.
+
+    host, port = parse_host_port(proxy, default_port=1080)
+    return HostnameEndpoint(reactor, host, port, **kwargs)
+
+
+def parse_host_port(hostport, default_port=None):
+    # could have sworn we had one of these somewhere else...
+    if b":" in hostport:
+        host, port = hostport.rsplit(b":", 1)
+        try:
+            port = int(port)
+            return host, port
+        except ValueError:
+            # the thing after the : wasn't a valid port; presumably this is an
+            # IPv6 address.
+            pass
+
+    return hostport, default_port
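
A rough usage sketch for the new agent (an editor's illustration, not part of the patch;
the proxy address is made up, and bytes proxy settings are used as in the tests later in
this series). Note that parse_host_port falls back to port 1080 when no port is given:

    from twisted.internet import reactor

    from synapse.http.proxyagent import ProxyAgent, parse_host_port

    parse_host_port(b"proxy.example.com:8888")                # (b"proxy.example.com", 8888)
    parse_host_port(b"proxy.example.com", default_port=1080)  # (b"proxy.example.com", 1080)

    agent = ProxyAgent(
        reactor,
        http_proxy=b"proxy.example.com:8888",   # plain-HTTP requests are sent via the proxy
        https_proxy=b"proxy.example.com:8888",  # HTTPS requests are CONNECT-tunnelled through it
    )
    # d fires with an IResponse once the response headers arrive
    d = agent.request(b"GET", b"http://matrix.org/")
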
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 7dde2ad055..e994037be6 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -103,7 +103,7 @@ class HttpPusher(object):
         if "url" not in self.data:
             raise PusherConfigException("'url' required in data for HTTP pusher")
         self.url = self.data["url"]
-        self.http_client = hs.get_simple_http_client()
+        self.http_client = hs.get_proxied_http_client()
         self.data_minus_url = {}
         self.data_minus_url.update(self.data)
         del self.data_minus_url["url"]
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 00a7dd6d09..24a0ce74f2 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -381,7 +381,7 @@ class CasTicketServlet(RestServlet):
         self.cas_displayname_attribute = hs.config.cas_displayname_attribute
         self.cas_required_attributes = hs.config.cas_required_attributes
         self._sso_auth_handler = SSOAuthHandler(hs)
-        self._http_client = hs.get_simple_http_client()
+        self._http_client = hs.get_proxied_http_client()
 
     @defer.inlineCallbacks
     def on_GET(self, request):
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 5a25b6b3fc..531d923f76 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource):
             treq_args={"browser_like_redirects": True},
             ip_whitelist=hs.config.url_preview_ip_range_whitelist,
             ip_blacklist=hs.config.url_preview_ip_range_blacklist,
+            http_proxy=os.getenv("http_proxy"),
+            https_proxy=os.getenv("HTTPS_PROXY"),
         )
         self.media_repo = media_repo
         self.primary_base_path = media_repo.primary_base_path
diff --git a/synapse/server.py b/synapse/server.py
index 0b81af646c..f8aeebcff8 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -23,6 +23,7 @@
 # Imports required for the default HomeServer() implementation
 import abc
 import logging
+import os
 
 from twisted.enterprise import adbapi
 from twisted.mail.smtp import sendmail
@@ -168,6 +169,7 @@ class HomeServer(object):
         "filtering",
         "http_client_context_factory",
         "simple_http_client",
+        "proxied_http_client",
         "media_repository",
         "media_repository_resource",
         "federation_transport_client",
@@ -311,6 +313,13 @@ class HomeServer(object):
     def build_simple_http_client(self):
         return SimpleHttpClient(self)
 
+    def build_proxied_http_client(self):
+        return SimpleHttpClient(
+            self,
+            http_proxy=os.getenv("http_proxy"),
+            https_proxy=os.getenv("HTTPS_PROXY"),
+        )
+
     def build_room_creation_handler(self):
         return RoomCreationHandler(self)
 
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 83d1f11283..b5e0b57095 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -12,6 +12,7 @@ import synapse.handlers.message
 import synapse.handlers.room
 import synapse.handlers.room_member
 import synapse.handlers.set_password
+import synapse.http.client
 import synapse.rest.media.v1.media_repository
 import synapse.server_notices.server_notices_manager
 import synapse.server_notices.server_notices_sender
@@ -38,6 +39,14 @@ class HomeServer(object):
         pass
     def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
         pass
+    def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
+        """Fetch an HTTP client implementation which doesn't do any blacklisting
+        or support any HTTP_PROXY settings"""
+        pass
+    def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
+        """Fetch an HTTP client implementation which doesn't do any blacklisting
+        but does support HTTP_PROXY settings"""
+        pass
     def get_deactivate_account_handler(
         self,
     ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
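
For context, a simplified sketch (editor's illustration, not Synapse's actual code) of the
accessor pattern assumed by the server.py change above: listing "proxied_http_client" in
DEPENDENCIES results in a cached get_proxied_http_client() accessor, declared in the stub
above, which calls build_proxied_http_client() on first use:

    class MiniHomeServer:
        # each name listed here gets a cached get_<name>() built from build_<name>()
        DEPENDENCIES = ["proxied_http_client"]

        def __init__(self):
            self._cache = {}

        def build_proxied_http_client(self):
            return object()  # stands in for SimpleHttpClient(self, http_proxy=..., https_proxy=...)

        def __getattr__(self, name):
            if name.startswith("get_") and name[4:] in self.DEPENDENCIES:
                key = name[4:]

                def getter():
                    if key not in self._cache:
                        self._cache[key] = getattr(self, "build_" + key)()
                    return self._cache[key]

                return getter
            raise AttributeError(name)

    hs = MiniHomeServer()
    assert hs.get_proxied_http_client() is hs.get_proxied_http_client()
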
diff --git a/tests/http/__init__.py b/tests/http/__init__.py
index 2d5dba6464..2096ba3c91 100644
--- a/tests/http/__init__.py
+++ b/tests/http/__init__.py
@@ -20,6 +20,23 @@ from zope.interface import implementer
 from OpenSSL import SSL
 from OpenSSL.SSL import Connection
 from twisted.internet.interfaces import IOpenSSLServerConnectionCreator
+from twisted.internet.ssl import Certificate, trustRootFromCertificates
+from twisted.web.client import BrowserLikePolicyForHTTPS  # noqa: F401
+from twisted.web.iweb import IPolicyForHTTPS  # noqa: F401
+
+
+def get_test_https_policy():
+    """Get a test IPolicyForHTTPS which trusts the test CA cert
+
+    Returns:
+        IPolicyForHTTPS
+    """
+    ca_file = get_test_ca_cert_file()
+    with open(ca_file) as stream:
+        content = stream.read()
+    cert = Certificate.loadPEM(content)
+    trust_root = trustRootFromCertificates([cert])
+    return BrowserLikePolicyForHTTPS(trustRoot=trust_root)
 
 
 def get_test_ca_cert_file():
diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py
index 71d7025264..cfcd98ff7d 100644
--- a/tests/http/federation/test_matrix_federation_agent.py
+++ b/tests/http/federation/test_matrix_federation_agent.py
@@ -124,19 +124,24 @@ class MatrixFederationAgentTests(unittest.TestCase):
             FakeTransport(client_protocol, self.reactor, server_tls_protocol)
         )
 
+        # grab a hold of the TLS connection, in case it gets torn down
+        server_tls_connection = server_tls_protocol._tlsConnection
+
+        # fish the test server back out of the server-side TLS protocol.
+        http_protocol = server_tls_protocol.wrappedProtocol
+
         # give the reactor a pump to get the TLS juices flowing.
         self.reactor.pump((0.1,))
 
         # check the SNI
-        server_name = server_tls_protocol._tlsConnection.get_servername()
+        server_name = server_tls_connection.get_servername()
         self.assertEqual(
             server_name,
             expected_sni,
             "Expected SNI %s but got %s" % (expected_sni, server_name),
         )
 
-        # fish the test server back out of the server-side TLS protocol.
-        return server_tls_protocol.wrappedProtocol
+        return http_protocol
 
     @defer.inlineCallbacks
     def _make_get_request(self, uri):
diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py
new file mode 100644
index 0000000000..22abf76515
--- /dev/null
+++ b/tests/http/test_proxyagent.py
@@ -0,0 +1,334 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+import treq
+
+from twisted.internet import interfaces  # noqa: F401
+from twisted.internet.protocol import Factory
+from twisted.protocols.tls import TLSMemoryBIOFactory
+from twisted.web.http import HTTPChannel
+
+from synapse.http.proxyagent import ProxyAgent
+
+from tests.http import TestServerTLSConnectionFactory, get_test_https_policy
+from tests.server import FakeTransport, ThreadedMemoryReactorClock
+from tests.unittest import TestCase
+
+logger = logging.getLogger(__name__)
+
+HTTPFactory = Factory.forProtocol(HTTPChannel)
+
+
+class MatrixFederationAgentTests(TestCase):
+    def setUp(self):
+        self.reactor = ThreadedMemoryReactorClock()
+
+    def _make_connection(
+        self, client_factory, server_factory, ssl=False, expected_sni=None
+    ):
+        """Builds a test server, and completes the outgoing client connection
+
+        Args:
+            client_factory (interfaces.IProtocolFactory): the factory that the
+                application is trying to use to make the outbound connection. We will
+                invoke it to build the client Protocol
+
+            server_factory (interfaces.IProtocolFactory): a factory to build the
+                server-side protocol
+
+            ssl (bool): If true, we will expect an ssl connection and wrap
+                server_factory with a TLSMemoryBIOFactory
+
+            expected_sni (bytes|None): the expected SNI value
+
+        Returns:
+            IProtocol: the server Protocol returned by server_factory
+        """
+        if ssl:
+            server_factory = _wrap_server_factory_for_tls(server_factory)
+
+        server_protocol = server_factory.buildProtocol(None)
+
+        # now, tell the client protocol factory to build the client protocol,
+        # and wire the output of said protocol up to the server via
+        # a FakeTransport.
+        #
+        # Normally this would be done by the TCP socket code in Twisted, but we are
+        # stubbing that out here.
+        client_protocol = client_factory.buildProtocol(None)
+        client_protocol.makeConnection(
+            FakeTransport(server_protocol, self.reactor, client_protocol)
+        )
+
+        # tell the server protocol to send its stuff back to the client, too
+        server_protocol.makeConnection(
+            FakeTransport(client_protocol, self.reactor, server_protocol)
+        )
+
+        if ssl:
+            http_protocol = server_protocol.wrappedProtocol
+            tls_connection = server_protocol._tlsConnection
+        else:
+            http_protocol = server_protocol
+            tls_connection = None
+
+        # give the reactor a pump to get the TLS juices flowing (if needed)
+        self.reactor.advance(0)
+
+        if expected_sni is not None:
+            server_name = tls_connection.get_servername()
+            self.assertEqual(
+                server_name,
+                expected_sni,
+                "Expected SNI %s but got %s" % (expected_sni, server_name),
+            )
+
+        return http_protocol
+
+    def test_http_request(self):
+        agent = ProxyAgent(self.reactor)
+
+        self.reactor.lookups["test.com"] = "1.2.3.4"
+        d = agent.request(b"GET", b"http://test.com")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 80)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
+    def test_https_request(self):
+        agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+
+        self.reactor.lookups["test.com"] = "1.2.3.4"
+        d = agent.request(b"GET", b"https://test.com/abc")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 443)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory,
+            _get_test_protocol_factory(),
+            ssl=True,
+            expected_sni=b"test.com",
+        )
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/abc")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
+    def test_http_request_via_proxy(self):
+        agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+
+        self.reactor.lookups["proxy.com"] = "1.2.3.5"
+        d = agent.request(b"GET", b"http://test.com")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.5")
+        self.assertEqual(port, 8888)
+
+        # make a test server, and wire up the client
+        http_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"http://test.com")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
+    def test_https_request_via_proxy(self):
+        agent = ProxyAgent(
+            self.reactor,
+            contextFactory=get_test_https_policy(),
+            https_proxy=b"proxy.com",
+        )
+
+        self.reactor.lookups["proxy.com"] = "1.2.3.5"
+        d = agent.request(b"GET", b"https://test.com/abc")
+
+        # there should be a pending TCP connection
+        clients = self.reactor.tcpClients
+        self.assertEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
+        self.assertEqual(host, "1.2.3.5")
+        self.assertEqual(port, 1080)
+
+        # make a test HTTP server, and wire up the client
+        proxy_server = self._make_connection(
+            client_factory, _get_test_protocol_factory()
+        )
+
+        # fish the transports back out so that we can do the old switcheroo
+        s2c_transport = proxy_server.transport
+        client_protocol = s2c_transport.other
+        c2s_transport = client_protocol.transport
+
+        # the FakeTransport is async, so we need to pump the reactor
+        self.reactor.advance(0)
+
+        # now there should be a pending CONNECT request
+        self.assertEqual(len(proxy_server.requests), 1)
+
+        request = proxy_server.requests[0]
+        self.assertEqual(request.method, b"CONNECT")
+        self.assertEqual(request.path, b"test.com:443")
+
+        # tell the proxy server not to close the connection
+        proxy_server.persistent = True
+
+        # this just stops the http Request trying to do a chunked response
+        # request.setHeader(b"Content-Length", b"0")
+        request.finish()
+
+        # now we can replace the proxy channel with a new, SSL-wrapped HTTP channel
+        ssl_factory = _wrap_server_factory_for_tls(_get_test_protocol_factory())
+        ssl_protocol = ssl_factory.buildProtocol(None)
+        http_server = ssl_protocol.wrappedProtocol
+
+        ssl_protocol.makeConnection(
+            FakeTransport(client_protocol, self.reactor, ssl_protocol)
+        )
+        c2s_transport.other = ssl_protocol
+
+        self.reactor.advance(0)
+
+        server_name = ssl_protocol._tlsConnection.get_servername()
+        expected_sni = b"test.com"
+        self.assertEqual(
+            server_name,
+            expected_sni,
+            "Expected SNI %s but got %s" % (expected_sni, server_name),
+        )
+
+        # now there should be a pending request
+        self.assertEqual(len(http_server.requests), 1)
+
+        request = http_server.requests[0]
+        self.assertEqual(request.method, b"GET")
+        self.assertEqual(request.path, b"/abc")
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        request.write(b"result")
+        request.finish()
+
+        self.reactor.advance(0)
+
+        resp = self.successResultOf(d)
+        body = self.successResultOf(treq.content(resp))
+        self.assertEqual(body, b"result")
+
+
+def _wrap_server_factory_for_tls(factory, sanlist=None):
+    """Wrap an existing Protocol Factory with a test TLSMemoryBIOFactory
+
+    The resultant factory will create a TLS server which presents a certificate
+    signed by our test CA, valid for the domains in `sanlist`
+
+    Args:
+        factory (interfaces.IProtocolFactory): protocol factory to wrap
+        sanlist (iterable[bytes]): list of domains the cert should be valid for
+
+    Returns:
+        interfaces.IProtocolFactory
+    """
+    if sanlist is None:
+        sanlist = [b"DNS:test.com"]
+
+    connection_creator = TestServerTLSConnectionFactory(sanlist=sanlist)
+    return TLSMemoryBIOFactory(
+        connection_creator, isClient=False, wrappedFactory=factory
+    )
+
+
+def _get_test_protocol_factory():
+    """Get a protocol Factory which will build an HTTPChannel
+
+    Returns:
+        interfaces.IProtocolFactory
+    """
+    server_factory = Factory.forProtocol(HTTPChannel)
+
+    # Request.finish expects the factory to have a 'log' method.
+    server_factory.log = _log_request
+
+    return server_factory
+
+
+def _log_request(request):
+    """Implements Factory.log, which is expected by Request.finish"""
+    logger.info("Completed request %s", request)
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8ce6bb62da..af2327fb66 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -50,7 +50,7 @@ class HTTPPusherTests(HomeserverTestCase):
         config = self.default_config()
         config["start_pushers"] = True
 
-        hs = self.setup_test_homeserver(config=config, simple_http_client=m)
+        hs = self.setup_test_homeserver(config=config, proxied_http_client=m)
 
         return hs
 
diff --git a/tests/server.py b/tests/server.py
index 469efb4edb..f878aeaada 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -395,11 +395,24 @@ class FakeTransport(object):
             self.disconnecting = True
             if self._protocol:
                 self._protocol.connectionLost(reason)
-            self.disconnected = True
+
+            # if we still have data to write, delay until that is done
+            if self.buffer:
+                logger.info(
+                    "FakeTransport: Delaying disconnect until buffer is flushed"
+                )
+            else:
+                self.disconnected = True
 
     def abortConnection(self):
         logger.info("FakeTransport: abortConnection()")
-        self.loseConnection()
+
+        if not self.disconnecting:
+            self.disconnecting = True
+            if self._protocol:
+                self._protocol.connectionLost(None)
+
+        self.disconnected = True
 
     def pauseProducing(self):
         if not self.producer:
@@ -430,6 +443,9 @@ class FakeTransport(object):
             self._reactor.callLater(0.0, _produce)
 
     def write(self, byt):
+        if self.disconnecting:
+            raise Exception("Writing to disconnecting FakeTransport")
+
         self.buffer = self.buffer + byt
 
         # always actually do the write asynchronously. Some protocols (notably the
@@ -474,6 +490,10 @@ class FakeTransport(object):
         if self.buffer and self.autoflush:
             self._reactor.callLater(0.0, self.flush)
 
+        if not self.buffer and self.disconnecting:
+            logger.info("FakeTransport: Buffer now empty, completing disconnect")
+            self.disconnected = True
+
 
 def connect_client(reactor: IReactorTCP, client_id: int) -> AccumulatingProtocol:
     """
-- 
cgit 1.5.1


From c6516adbe03a0acdd614ba6eb9d6f447dd4259e9 Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Fri, 1 Nov 2019 16:19:09 +0000
Subject: Factor out an _AsyncEventContextImpl (#6298)

The intention here is to make it clearer which fields we can expect to be
populated when: notably, that the _event_type etc aren't used for the
synchronous impl of EventContext.
---
 changelog.d/6298.misc          |   1 +
 synapse/events/snapshot.py     | 107 ++++++++++++++++-------------------------
 synapse/handlers/federation.py |  38 +++++++--------
 tests/test_federation.py       |   4 +-
 4 files changed, 65 insertions(+), 85 deletions(-)
 create mode 100644 changelog.d/6298.misc

(limited to 'tests')

diff --git a/changelog.d/6298.misc b/changelog.d/6298.misc
new file mode 100644
index 0000000000..d4190730b2
--- /dev/null
+++ b/changelog.d/6298.misc
@@ -0,0 +1 @@
+Refactor EventContext for clarity.
\ No newline at end of file
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 27cd8a63ff..a269de5482 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -37,9 +37,6 @@ class EventContext:
         delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
             (type, state_key) -> event_id. ``None`` for an outlier.
 
-        prev_state_events (?): XXX: is this ever set to anything other than
-            the empty list?
-
         app_service: FIXME
 
         _current_state_ids (dict[(str, str), str]|None):
@@ -51,36 +48,16 @@ class EventContext:
             The current state map excluding the current event. None if outlier
             or we haven't fetched the state from DB yet.
             (type, state_key) -> event_id
-
-        _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
-            been calculated. None if we haven't started calculating yet
-
-        _event_type (str): The type of the event the context is associated with.
-            Only set when state has not been fetched yet.
-
-        _event_state_key (str|None): The state_key of the event the context is
-            associated with. Only set when state has not been fetched yet.
-
-        _prev_state_id (str|None): If the event associated with the context is
-            a state event, then `_prev_state_id` is the event_id of the state
-            that was replaced.
-            Only set when state has not been fetched yet.
     """
 
     state_group = attr.ib(default=None)
     rejected = attr.ib(default=False)
     prev_group = attr.ib(default=None)
     delta_ids = attr.ib(default=None)
-    prev_state_events = attr.ib(default=attr.Factory(list))
     app_service = attr.ib(default=None)
 
-    _current_state_ids = attr.ib(default=None)
     _prev_state_ids = attr.ib(default=None)
-    _prev_state_id = attr.ib(default=None)
-
-    _event_type = attr.ib(default=None)
-    _event_state_key = attr.ib(default=None)
-    _fetching_state_deferred = attr.ib(default=None)
+    _current_state_ids = attr.ib(default=None)
 
     @staticmethod
     def with_state(
@@ -90,7 +67,6 @@ class EventContext:
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
             state_group=state_group,
-            fetching_state_deferred=defer.succeed(None),
             prev_group=prev_group,
             delta_ids=delta_ids,
         )
@@ -125,7 +101,6 @@ class EventContext:
             "rejected": self.rejected,
             "prev_group": self.prev_group,
             "delta_ids": _encode_state_dict(self.delta_ids),
-            "prev_state_events": self.prev_state_events,
             "app_service_id": self.app_service.id if self.app_service else None,
         }
 
@@ -141,7 +116,7 @@ class EventContext:
         Returns:
             EventContext
         """
-        context = EventContext(
+        context = _AsyncEventContextImpl(
             # We use the state_group and prev_state_id stuff to pull the
             # current_state_ids out of the DB and construct prev_state_ids.
             prev_state_id=input["prev_state_id"],
@@ -151,7 +126,6 @@ class EventContext:
             prev_group=input["prev_group"],
             delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
-            prev_state_events=input["prev_state_events"],
         )
 
         app_service_id = input["app_service_id"]
@@ -170,14 +144,7 @@ class EventContext:
                 Maps a (type, state_key) to the event ID of the state event matching
                 this tuple.
         """
-
-        if not self._fetching_state_deferred:
-            self._fetching_state_deferred = run_in_background(
-                self._fill_out_state, store
-            )
-
-        yield make_deferred_yieldable(self._fetching_state_deferred)
-
+        yield self._ensure_fetched(store)
         return self._current_state_ids
 
     @defer.inlineCallbacks
@@ -190,14 +157,7 @@ class EventContext:
                 Maps a (type, state_key) to the event ID of the state event matching
                 this tuple.
         """
-
-        if not self._fetching_state_deferred:
-            self._fetching_state_deferred = run_in_background(
-                self._fill_out_state, store
-            )
-
-        yield make_deferred_yieldable(self._fetching_state_deferred)
-
+        yield self._ensure_fetched(store)
         return self._prev_state_ids
 
     def get_cached_current_state_ids(self):
@@ -211,6 +171,44 @@ class EventContext:
 
         return self._current_state_ids
 
+    def _ensure_fetched(self, store):
+        return defer.succeed(None)
+
+
+@attr.s(slots=True)
+class _AsyncEventContextImpl(EventContext):
+    """
+    An implementation of EventContext which fetches _current_state_ids and
+    _prev_state_ids from the database on demand.
+
+    Attributes:
+
+        _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
+            been calculated. None if we haven't started calculating yet
+
+        _event_type (str): The type of the event the context is associated with.
+
+        _event_state_key (str): The state_key of the event the context is
+            associated with.
+
+        _prev_state_id (str|None): If the event associated with the context is
+            a state event, then `_prev_state_id` is the event_id of the state
+            that was replaced.
+    """
+
+    _prev_state_id = attr.ib(default=None)
+    _event_type = attr.ib(default=None)
+    _event_state_key = attr.ib(default=None)
+    _fetching_state_deferred = attr.ib(default=None)
+
+    def _ensure_fetched(self, store):
+        if not self._fetching_state_deferred:
+            self._fetching_state_deferred = run_in_background(
+                self._fill_out_state, store
+            )
+
+        return make_deferred_yieldable(self._fetching_state_deferred)
+
     @defer.inlineCallbacks
     def _fill_out_state(self, store):
         """Called to populate the _current_state_ids and _prev_state_ids
@@ -228,27 +226,6 @@ class EventContext:
         else:
             self._prev_state_ids = self._current_state_ids
 
-    @defer.inlineCallbacks
-    def update_state(
-        self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
-    ):
-        """Replace the state in the context
-        """
-
-        # We need to make sure we wait for any ongoing fetching of state
-        # to complete so that the updated state doesn't get clobbered
-        if self._fetching_state_deferred:
-            yield make_deferred_yieldable(self._fetching_state_deferred)
-
-        self.state_group = state_group
-        self._prev_state_ids = prev_state_ids
-        self.prev_group = prev_group
-        self._current_state_ids = current_state_ids
-        self.delta_ids = delta_ids
-
-        # We need to ensure that that we've marked as having fetched the state
-        self._fetching_state_deferred = defer.succeed(None)
-
 
 def _encode_state_dict(state_dict):
     """Since dicts of (type, state_key) -> event_id cannot be serialized in
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index dab6be9573..8cafcfdab0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.event_auth import auth_types_for_event
+from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.logging.context import (
     make_deferred_yieldable,
@@ -1871,14 +1872,7 @@ class FederationHandler(BaseHandler):
                 if c and c.type == EventTypes.Create:
                     auth_events[(c.type, c.state_key)] = c
 
-        try:
-            yield self.do_auth(origin, event, context, auth_events=auth_events)
-        except AuthError as e:
-            logger.warning(
-                "[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
-            )
-
-            context.rejected = RejectedReason.AUTH_ERROR
+        context = yield self.do_auth(origin, event, context, auth_events=auth_events)
 
         if not context.rejected:
             yield self._check_for_soft_fail(event, state, backfilled)
@@ -2047,12 +2041,12 @@ class FederationHandler(BaseHandler):
 
                 Also NB that this function adds entries to it.
         Returns:
-            defer.Deferred[None]
+            defer.Deferred[EventContext]: updated context object
         """
         room_version = yield self.store.get_room_version(event.room_id)
 
         try:
-            yield self._update_auth_events_and_context_for_auth(
+            context = yield self._update_auth_events_and_context_for_auth(
                 origin, event, context, auth_events
             )
         except Exception:
@@ -2070,7 +2064,9 @@ class FederationHandler(BaseHandler):
             event_auth.check(room_version, event, auth_events=auth_events)
         except AuthError as e:
             logger.warning("Failed auth resolution for %r because %s", event, e)
-            raise e
+            context.rejected = RejectedReason.AUTH_ERROR
+
+        return context
 
     @defer.inlineCallbacks
     def _update_auth_events_and_context_for_auth(
@@ -2094,7 +2090,7 @@ class FederationHandler(BaseHandler):
             auth_events (dict[(str, str)->synapse.events.EventBase]):
 
         Returns:
-            defer.Deferred[None]
+            defer.Deferred[EventContext]: updated context
         """
         event_auth_events = set(event.auth_event_ids())
 
@@ -2133,7 +2129,7 @@ class FederationHandler(BaseHandler):
                     # The other side isn't around or doesn't implement the
                     # endpoint, so lets just bail out.
                     logger.info("Failed to get event auth from remote: %s", e)
-                    return
+                    return context
 
                 seen_remotes = yield self.store.have_seen_events(
                     [e.event_id for e in remote_auth_chain]
@@ -2174,7 +2170,7 @@ class FederationHandler(BaseHandler):
 
         if event.internal_metadata.is_outlier():
             logger.info("Skipping auth_event fetch for outlier")
-            return
+            return context
 
         # FIXME: Assumes we have and stored all the state for all the
         # prev_events
@@ -2183,7 +2179,7 @@ class FederationHandler(BaseHandler):
         )
 
         if not different_auth:
-            return
+            return context
 
         logger.info(
             "auth_events refers to events which are not in our calculated auth "
@@ -2230,10 +2226,12 @@ class FederationHandler(BaseHandler):
 
             auth_events.update(new_state)
 
-            yield self._update_context_for_auth_events(
+            context = yield self._update_context_for_auth_events(
                 event, context, auth_events, event_key
             )
 
+        return context
+
     @defer.inlineCallbacks
     def _update_context_for_auth_events(self, event, context, auth_events, event_key):
         """Update the state_ids in an event context after auth event resolution,
@@ -2242,14 +2240,16 @@ class FederationHandler(BaseHandler):
         Args:
             event (Event): The event we're handling the context for
 
-            context (synapse.events.snapshot.EventContext): event context
-                to be updated
+            context (synapse.events.snapshot.EventContext): initial event context
 
             auth_events (dict[(str, str)->str]): Events to update in the event
                 context.
 
             event_key ((str, str)): (type, state_key) for the current event.
                 this will not be included in the current_state in the context.
+
+        Returns:
+            Deferred[EventContext]: new event context
         """
         state_updates = {
             k: a.event_id for k, a in iteritems(auth_events) if k != event_key
@@ -2274,7 +2274,7 @@ class FederationHandler(BaseHandler):
             current_state_ids=current_state_ids,
         )
 
-        yield context.update_state(
+        return EventContext.with_state(
             state_group=state_group,
             current_state_ids=current_state_ids,
             prev_state_ids=prev_state_ids,
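
The net effect of the handler changes above, as an editor's summary sketch (the caller and
argument names are placeholders): callers must now rebind `context` to the value returned
by do_auth, since auth resolution may replace the context object entirely (via
EventContext.with_state) rather than mutating it in place, and rejection is signalled by
`context.rejected` rather than by an AuthError. The test mock below mirrors that contract:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def handle_incoming_event(handler, origin, event, context, auth_events):
        # do_auth now returns the (possibly replaced) context instead of raising
        context = yield handler.do_auth(origin, event, context, auth_events=auth_events)
        if not context.rejected:
            pass  # continue processing with the returned context
        return context
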
diff --git a/tests/test_federation.py b/tests/test_federation.py
index d1acb16f30..7d82b58466 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -59,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         self.handler = self.homeserver.get_handlers().federation_handler
-        self.handler.do_auth = lambda *a, **b: succeed(True)
+        self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
+            context
+        )
         self.client = self.homeserver.get_federation_client()
         self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
             pdus
-- 
cgit 1.5.1


From 988d8d6507a0e8b34f2c352c77b5742197762190 Mon Sep 17 00:00:00 2001
From: Brendan Abolivier <babolivier@matrix.org>
Date: Fri, 1 Nov 2019 16:22:44 +0000
Subject: Incorporate review

---
 synapse/api/constants.py                                          | 2 +-
 synapse/api/filtering.py                                          | 2 +-
 synapse/storage/data_stores/main/events.py                        | 2 +-
 synapse/storage/data_stores/main/schema/delta/56/event_labels.sql | 6 ++++++
 tests/api/test_filtering.py                                       | 8 ++++----
 tests/rest/client/v1/test_rooms.py                                | 8 ++++----
 tests/rest/client/v2_alpha/test_sync.py                           | 8 ++++----
 7 files changed, 21 insertions(+), 15 deletions(-)

(limited to 'tests')

diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 066ce18704..49c4b85054 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -144,4 +144,4 @@ class EventContentFields(object):
     """Fields found in events' content, regardless of type."""
 
     # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
-    Labels = "org.matrix.labels"
+    LABELS = "org.matrix.labels"
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 30a7ee0a7a..bec13f08d8 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -309,7 +309,7 @@ class Filter(object):
             content = event.get("content", {})
             # check if there is a string url field in the content for filtering purposes
             contains_url = isinstance(content.get("url"), text_type)
-            labels = content.get(EventContentFields.Labels, [])
+            labels = content.get(EventContentFields.LABELS, [])
 
         return self.check_fields(room_id, sender, ev_type, labels, contains_url)
 
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 577e79bcf9..1045c7fa2e 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1491,7 +1491,7 @@ class EventsStore(
             self._handle_event_relations(txn, event)
 
             # Store the labels for this event.
-            labels = event.content.get(EventContentFields.Labels)
+            labels = event.content.get(EventContentFields.LABELS)
             if labels:
                 self.insert_labels_for_event_txn(
                     txn, event.event_id, labels, event.room_id, event.depth
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
index 2acd8e1be5..5e29c1da19 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
@@ -13,6 +13,8 @@
  * limitations under the License.
  */
 
+-- room_id and topological_ordering are denormalised from the events table in order to
+-- make the index work.
 CREATE TABLE IF NOT EXISTS event_labels (
     event_id TEXT,
     label TEXT,
@@ -21,4 +23,8 @@ CREATE TABLE IF NOT EXISTS event_labels (
     PRIMARY KEY(event_id, label)
 );
 
+
+-- This index lets an event pagination that filters on a particular label hit the
+-- event_labels table first, which is much quicker than scanning the events table and
+-- then filtering by label, if the label is rarely used relative to the size of the room.
 CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering);
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 8ec48c4154..2dc5052249 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -329,7 +329,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={EventContentFields.Labels: ["#fun"]},
+            content={EventContentFields.LABELS: ["#fun"]},
         )
 
         self.assertTrue(Filter(definition).check(event))
@@ -338,7 +338,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={EventContentFields.Labels: ["#notfun"]},
+            content={EventContentFields.LABELS: ["#notfun"]},
         )
 
         self.assertFalse(Filter(definition).check(event))
@@ -349,7 +349,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={EventContentFields.Labels: ["#fun"]},
+            content={EventContentFields.LABELS: ["#fun"]},
         )
 
         self.assertFalse(Filter(definition).check(event))
@@ -358,7 +358,7 @@ class FilteringTestCase(unittest.TestCase):
             sender="@foo:bar",
             type="m.room.message",
             room_id="!secretbase:unknown",
-            content={EventContentFields.Labels: ["#notfun"]},
+            content={EventContentFields.LABELS: ["#notfun"]},
         )
 
         self.assertTrue(Filter(definition).check(event))
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 0dc0faa0e5..5e38fd6ced 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -860,7 +860,7 @@ class RoomMessageListTestCase(RoomBase):
             content={
                 "msgtype": "m.text",
                 "body": "with right label",
-                EventContentFields.Labels: ["#fun"],
+                EventContentFields.LABELS: ["#fun"],
             },
         )
 
@@ -876,7 +876,7 @@ class RoomMessageListTestCase(RoomBase):
             content={
                 "msgtype": "m.text",
                 "body": "with wrong label",
-                EventContentFields.Labels: ["#work"],
+                EventContentFields.LABELS: ["#work"],
             },
         )
 
@@ -886,7 +886,7 @@ class RoomMessageListTestCase(RoomBase):
             content={
                 "msgtype": "m.text",
                 "body": "with two wrong labels",
-                EventContentFields.Labels: ["#work", "#notfun"],
+                EventContentFields.LABELS: ["#work", "#notfun"],
             },
         )
 
@@ -896,7 +896,7 @@ class RoomMessageListTestCase(RoomBase):
             content={
                 "msgtype": "m.text",
                 "body": "with right label",
-                EventContentFields.Labels: ["#fun"],
+                EventContentFields.LABELS: ["#fun"],
             },
         )
 
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index c3c6f75ced..3283c0e47b 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -157,7 +157,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             content={
                 "msgtype": "m.text",
                 "body": "with right label",
-                EventContentFields.Labels: ["#fun"],
+                EventContentFields.LABELS: ["#fun"],
             },
             tok=tok,
         )
@@ -175,7 +175,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             content={
                 "msgtype": "m.text",
                 "body": "with wrong label",
-                EventContentFields.Labels: ["#work"],
+                EventContentFields.LABELS: ["#work"],
             },
             tok=tok,
         )
@@ -186,7 +186,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             content={
                 "msgtype": "m.text",
                 "body": "with two wrong labels",
-                EventContentFields.Labels: ["#work", "#notfun"],
+                EventContentFields.LABELS: ["#work", "#notfun"],
             },
             tok=tok,
         )
@@ -197,7 +197,7 @@ class SyncFilterTestCase(unittest.HomeserverTestCase):
             content={
                 "msgtype": "m.text",
                 "body": "with right label",
-                EventContentFields.Labels: ["#fun"],
+                EventContentFields.LABELS: ["#fun"],
             },
             tok=tok,
         )
-- 
cgit 1.5.1