diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 9c55e79ef5..942d7c721f 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -78,6 +78,9 @@ CONDITIONAL_REQUIREMENTS = {
"affinity": {
"affinity": ["affinity"],
},
+ "postgres": {
+ "psycopg2>=2.6": ["psycopg2"]
+ }
}
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 08dffd774f..be61147b9b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -17,9 +17,10 @@ import sys
import threading
import time
-from six import iteritems, iterkeys, itervalues
+from six import PY2, iteritems, iterkeys, itervalues
from six.moves import intern, range
+from canonicaljson import json
from prometheus_client import Histogram
from twisted.internet import defer
@@ -1216,3 +1217,32 @@ class _RollbackButIsFineException(Exception):
something went wrong.
"""
pass
+
+
+def db_to_json(db_content):
+ """
+ Take some data from a database row and return a JSON-decoded object.
+
+ Args:
+ db_content (memoryview|buffer|bytes|bytearray|unicode)
+ """
+ # psycopg2 on Python 3 returns memoryview objects, which we need to
+ # cast to bytes to decode
+ if isinstance(db_content, memoryview):
+ db_content = db_content.tobytes()
+
+ # psycopg2 on Python 2 returns buffer objects, which we need to cast to
+ # bytes to decode
+ if PY2 and isinstance(db_content, buffer):
+ db_content = bytes(db_content)
+
+ # Decode it to a Unicode string before feeding it to json.loads, so we
+ # consistenty get a Unicode-containing object out.
+ if isinstance(db_content, (bytes, bytearray)):
+ db_content = db_content.decode('utf8')
+
+ try:
+ return json.loads(db_content)
+ except Exception:
+ logging.warning("Tried to decode '%r' as JSON and failed", db_content)
+ raise
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 73646da025..e06b0bc56d 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -169,7 +169,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
- devices = messages_by_device.keys()
+ devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
sql = (
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index c0943ecf91..d10ff9e4b9 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -24,7 +24,7 @@ from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
-from ._base import Cache, SQLBaseStore
+from ._base import Cache, SQLBaseStore, db_to_json
logger = logging.getLogger(__name__)
@@ -411,7 +411,7 @@ class DeviceStore(SQLBaseStore):
if device is not None:
key_json = device.get("key_json", None)
if key_json:
- result["keys"] = json.loads(key_json)
+ result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@@ -466,7 +466,7 @@ class DeviceStore(SQLBaseStore):
retcol="content",
desc="_get_cached_user_device",
)
- defer.returnValue(json.loads(content))
+ defer.returnValue(db_to_json(content))
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
@@ -479,7 +479,7 @@ class DeviceStore(SQLBaseStore):
desc="_get_cached_devices_for_user",
)
defer.returnValue({
- device["device_id"]: json.loads(device["content"])
+ device["device_id"]: db_to_json(device["content"])
for device in devices
})
@@ -511,7 +511,7 @@ class DeviceStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
- result["keys"] = json.loads(key_json)
+ result["keys"] = db_to_json(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 523b4360c3..1f1721e820 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,13 +14,13 @@
# limitations under the License.
from six import iteritems
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.util.caches.descriptors import cached
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
class EndToEndKeyStore(SQLBaseStore):
@@ -90,7 +90,7 @@ class EndToEndKeyStore(SQLBaseStore):
for user_id, device_keys in iteritems(results):
for device_id, device_info in iteritems(device_keys):
- device_info["keys"] = json.loads(device_info.pop("key_json"))
+ device_info["keys"] = db_to_json(device_info.pop("key_json"))
defer.returnValue(results)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 8a0386c1a4..42225f8a2a 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -41,13 +41,18 @@ class PostgresEngine(object):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
+
+ # Set the bytea output to escape, vs the default of hex
+ cursor = db_conn.cursor()
+ cursor.execute("SET bytea_output TO escape")
+
# Asynchronous commit, don't wait for the server to call fsync before
# ending the transaction.
# https://www.postgresql.org/docs/current/static/wal-async-commit.html
if not self.synchronous_commit:
- cursor = db_conn.cursor()
cursor.execute("SET synchronous_commit TO OFF")
- cursor.close()
+
+ cursor.close()
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index f39c8c8461..8bf87f38f7 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -19,7 +19,7 @@ import logging
from collections import OrderedDict, deque, namedtuple
from functools import wraps
-from six import iteritems
+from six import iteritems, text_type
from six.moves import range
from canonicaljson import json
@@ -1220,7 +1220,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
"sender": event.sender,
"contains_url": (
"url" in event.content
- and isinstance(event.content["url"], basestring)
+ and isinstance(event.content["url"], text_type)
),
}
for event, _ in events_and_contexts
@@ -1529,7 +1529,7 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
contains_url = "url" in content
if contains_url:
- contains_url &= isinstance(content["url"], basestring)
+ contains_url &= isinstance(content["url"], text_type)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -1910,9 +1910,9 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
(room_id,)
)
rows = txn.fetchall()
- max_depth = max(row[0] for row in rows)
+ max_depth = max(row[1] for row in rows)
- if max_depth <= token.topological:
+ if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 59822178ff..a8326f5296 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import itertools
import logging
from collections import namedtuple
@@ -265,7 +266,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
with Measure(self._clock, "_fetch_event_list"):
try:
- event_id_lists = zip(*event_list)[0]
+ event_id_lists = list(zip(*event_list))[0]
event_ids = [
item for sublist in event_id_lists for item in sublist
]
@@ -299,14 +300,14 @@ class EventsWorkerStore(SQLBaseStore):
logger.exception("do_fetch")
# We only want to resolve deferreds from the main thread
- def fire(evs):
+ def fire(evs, exc):
for _, d in evs:
if not d.called:
with PreserveLoggingContext():
- d.errback(e)
+ d.errback(exc)
with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list)
+ self.hs.get_reactor().callFromThread(fire, event_list, e)
@defer.inlineCallbacks
def _enqueue_events(self, events, check_redacted=True, allow_rejected=False):
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 2d5896c5b4..6ddcc909bf 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -13,14 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
class FilteringStore(SQLBaseStore):
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
+ defer.returnValue(db_to_json(def_json))
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 8443bd4c1b..c7987bfcdd 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -15,7 +15,8 @@
# limitations under the License.
import logging
-import types
+
+import six
from canonicaljson import encode_canonical_json, json
@@ -27,6 +28,11 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
+if six.PY2:
+ db_binary_type = buffer
+else:
+ db_binary_type = memoryview
+
class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
@@ -34,18 +40,18 @@ class PusherWorkerStore(SQLBaseStore):
dataJson = r['data']
r['data'] = None
try:
- if isinstance(dataJson, types.BufferType):
+ if isinstance(dataJson, db_binary_type):
dataJson = str(dataJson).decode("UTF8")
r['data'] = json.loads(dataJson)
except Exception as e:
logger.warn(
"Invalid JSON in data for pusher %d: %s, %s",
- r['id'], dataJson, e.message,
+ r['id'], dataJson, e.args[0],
)
pass
- if isinstance(r['pushkey'], types.BufferType):
+ if isinstance(r['pushkey'], db_binary_type):
r['pushkey'] = str(r['pushkey']).decode("UTF8")
return rows
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 428e7fa36e..0c42bd3322 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -18,14 +18,14 @@ from collections import namedtuple
import six
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util.caches.descriptors import cached
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
@@ -95,7 +95,8 @@ class TransactionStore(SQLBaseStore):
)
if result and result["response_code"]:
- return result["response_code"], json.loads(str(result["response_json"]))
+ return result["response_code"], db_to_json(result["response_json"])
+
else:
return None
|