diff --git a/crypto/src/crypto/tls/DtlsRecordLayer.cs b/crypto/src/crypto/tls/DtlsRecordLayer.cs
index 3cb0e78dd..266893df0 100644
--- a/crypto/src/crypto/tls/DtlsRecordLayer.cs
+++ b/crypto/src/crypto/tls/DtlsRecordLayer.cs
@@ -1,5 +1,6 @@
using System;
using System.IO;
+using System.Net.Sockets;
using Org.BouncyCastle.Utilities.Date;
@@ -13,6 +14,21 @@ namespace Org.BouncyCastle.Crypto.Tls
private const long TCP_MSL = 1000L * 60 * 2;
private const long RETRANSMIT_TIMEOUT = TCP_MSL * 2;
+ private static void SendDatagram(DatagramTransport sender, byte[] buf, int off, int len)
+ {
+ //try
+ //{
+ // sender.Send(buf, off, len);
+ //}
+ //catch (InterruptedIOException e)
+ //{
+ // e.bytesTransferred = 0;
+ // throw e;
+ //}
+
+ sender.Send(buf, off, len);
+ }
+
private readonly DatagramTransport mTransport;
private readonly TlsContext mContext;
private readonly TlsPeer mPeer;
@@ -134,6 +150,8 @@ namespace Org.BouncyCastle.Crypto.Tls
public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
{
+ // TODO Avoid returning -1 (timeout) until 'waitMillis' has definitely elapsed
+
byte[] record = null;
for (;;)
@@ -144,191 +162,183 @@ namespace Org.BouncyCastle.Crypto.Tls
record = new byte[receiveLimit];
}
- try
+ if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry)
{
- if (mRetransmit != null && DateTimeUtilities.CurrentUnixMs() > mRetransmitExpiry)
- {
- mRetransmit = null;
- mRetransmitEpoch = null;
- }
+ mRetransmit = null;
+ mRetransmitEpoch = null;
+ }
- int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
- if (received < 0)
- {
- return received;
- }
- if (received < RECORD_HEADER_LENGTH)
- {
- continue;
- }
- int length = TlsUtilities.ReadUint16(record, 11);
- if (received != (length + RECORD_HEADER_LENGTH))
- {
- continue;
- }
+ int received = ReceiveRecord(record, 0, receiveLimit, waitMillis);
+ if (received < 0)
+ {
+ return received;
+ }
+ if (received < RECORD_HEADER_LENGTH)
+ {
+ continue;
+ }
+ int length = TlsUtilities.ReadUint16(record, 11);
+ if (received != (length + RECORD_HEADER_LENGTH))
+ {
+ continue;
+ }
- byte type = TlsUtilities.ReadUint8(record, 0);
+ byte type = TlsUtilities.ReadUint8(record, 0);
- // TODO Support user-specified custom protocols?
- switch (type)
- {
- case ContentType.alert:
- case ContentType.application_data:
- case ContentType.change_cipher_spec:
- case ContentType.handshake:
- case ContentType.heartbeat:
- break;
- default:
- // TODO Exception?
- continue;
- }
+ // TODO Support user-specified custom protocols?
+ switch (type)
+ {
+ case ContentType.alert:
+ case ContentType.application_data:
+ case ContentType.change_cipher_spec:
+ case ContentType.handshake:
+ case ContentType.heartbeat:
+ break;
+ default:
+ // TODO Exception?
+ continue;
+ }
- int epoch = TlsUtilities.ReadUint16(record, 3);
+ int epoch = TlsUtilities.ReadUint16(record, 3);
- DtlsEpoch recordEpoch = null;
- if (epoch == mReadEpoch.Epoch)
- {
- recordEpoch = mReadEpoch;
- }
- else if (type == ContentType.handshake && mRetransmitEpoch != null
- && epoch == mRetransmitEpoch.Epoch)
- {
- recordEpoch = mRetransmitEpoch;
- }
+ DtlsEpoch recordEpoch = null;
+ if (epoch == mReadEpoch.Epoch)
+ {
+ recordEpoch = mReadEpoch;
+ }
+ else if (type == ContentType.handshake && mRetransmitEpoch != null
+ && epoch == mRetransmitEpoch.Epoch)
+ {
+ recordEpoch = mRetransmitEpoch;
+ }
- if (recordEpoch == null)
- {
- continue;
- }
+ if (recordEpoch == null)
+ {
+ continue;
+ }
- long seq = TlsUtilities.ReadUint48(record, 5);
- if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
- {
- continue;
- }
+ long seq = TlsUtilities.ReadUint48(record, 5);
+ if (recordEpoch.ReplayWindow.ShouldDiscard(seq))
+ {
+ continue;
+ }
- ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
- if (!version.IsDtls)
- {
- continue;
- }
+ ProtocolVersion version = TlsUtilities.ReadVersion(record, 1);
+ if (!version.IsDtls)
+ {
+ continue;
+ }
- if (mReadVersion != null && !mReadVersion.Equals(version))
- {
- continue;
- }
+ if (mReadVersion != null && !mReadVersion.Equals(version))
+ {
+ continue;
+ }
- byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext(
- GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH,
- received - RECORD_HEADER_LENGTH);
+ byte[] plaintext = recordEpoch.Cipher.DecodeCiphertext(
+ GetMacSequenceNumber(recordEpoch.Epoch, seq), type, record, RECORD_HEADER_LENGTH,
+ received - RECORD_HEADER_LENGTH);
- recordEpoch.ReplayWindow.ReportAuthenticated(seq);
+ recordEpoch.ReplayWindow.ReportAuthenticated(seq);
- if (plaintext.Length > this.mPlaintextLimit)
- {
- continue;
- }
+ if (plaintext.Length > this.mPlaintextLimit)
+ {
+ continue;
+ }
- if (mReadVersion == null)
- {
- mReadVersion = version;
- }
+ if (mReadVersion == null)
+ {
+ mReadVersion = version;
+ }
- switch (type)
- {
- case ContentType.alert:
+ switch (type)
+ {
+ case ContentType.alert:
+ {
+ if (plaintext.Length == 2)
{
- if (plaintext.Length == 2)
+ byte alertLevel = plaintext[0];
+ byte alertDescription = plaintext[1];
+
+ mPeer.NotifyAlertReceived(alertLevel, alertDescription);
+
+ if (alertLevel == AlertLevel.fatal)
{
- byte alertLevel = plaintext[0];
- byte alertDescription = plaintext[1];
-
- mPeer.NotifyAlertReceived(alertLevel, alertDescription);
-
- if (alertLevel == AlertLevel.fatal)
- {
- Failed();
- throw new TlsFatalAlert(alertDescription);
- }
-
- // TODO Can close_notify be a fatal alert?
- if (alertDescription == AlertDescription.close_notify)
- {
- CloseTransport();
- }
+ Failed();
+ throw new TlsFatalAlert(alertDescription);
}
+ // TODO Can close_notify be a fatal alert?
+ if (alertDescription == AlertDescription.close_notify)
+ {
+ CloseTransport();
+ }
+ }
+
+ continue;
+ }
+ case ContentType.application_data:
+ {
+ if (mInHandshake)
+ {
+ // TODO Consider buffering application data for new epoch that arrives
+ // out-of-order with the Finished message
continue;
}
- case ContentType.application_data:
+ break;
+ }
+ case ContentType.change_cipher_spec:
+ {
+ // Implicitly receive change_cipher_spec and change to pending cipher state
+
+ for (int i = 0; i < plaintext.Length; ++i)
{
- if (mInHandshake)
+ byte message = TlsUtilities.ReadUint8(plaintext, i);
+ if (message != ChangeCipherSpec.change_cipher_spec)
{
- // TODO Consider buffering application data for new epoch that arrives
- // out-of-order with the Finished message
continue;
}
- break;
- }
- case ContentType.change_cipher_spec:
- {
- // Implicitly receive change_cipher_spec and change to pending cipher state
- for (int i = 0; i < plaintext.Length; ++i)
+ if (mPendingEpoch != null)
{
- byte message = TlsUtilities.ReadUint8(plaintext, i);
- if (message != ChangeCipherSpec.change_cipher_spec)
- {
- continue;
- }
-
- if (mPendingEpoch != null)
- {
- mReadEpoch = mPendingEpoch;
- }
+ mReadEpoch = mPendingEpoch;
}
-
- continue;
}
- case ContentType.handshake:
+
+ continue;
+ }
+ case ContentType.handshake:
+ {
+ if (!mInHandshake)
{
- if (!mInHandshake)
+ if (mRetransmit != null)
{
- if (mRetransmit != null)
- {
- mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length);
- }
-
- // TODO Consider support for HelloRequest
- continue;
+ mRetransmit.ReceivedHandshakeRecord(epoch, plaintext, 0, plaintext.Length);
}
- break;
- }
- case ContentType.heartbeat:
- {
- // TODO[RFC 6520]
- continue;
- }
- }
- /*
- * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
- * received our final flight.
- */
- if (!mInHandshake && mRetransmit != null)
- {
- this.mRetransmit = null;
- this.mRetransmitEpoch = null;
+ // TODO Consider support for HelloRequest
+ continue;
}
-
- Array.Copy(plaintext, 0, buf, off, plaintext.Length);
- return plaintext.Length;
+ break;
}
- catch (IOException e)
+ case ContentType.heartbeat:
{
- // NOTE: Assume this is a timeout for the moment
- throw e;
+ // TODO[RFC 6520]
+ continue;
+ }
}
+
+ /*
+ * NOTE: If we receive any non-handshake data in the new epoch implies the peer has
+ * received our final flight.
+ */
+ if (!mInHandshake && mRetransmit != null)
+ {
+ this.mRetransmit = null;
+ this.mRetransmitEpoch = null;
+ }
+
+ Array.Copy(plaintext, 0, buf, off, plaintext.Length);
+ return plaintext.Length;
}
}
@@ -458,6 +468,35 @@ namespace Org.BouncyCastle.Crypto.Tls
SendRecord(ContentType.alert, error, 0, 2);
}
+ private int ReceiveDatagram(byte[] buf, int off, int len, int waitMillis)
+ {
+ //try
+ //{
+ // return mTransport.Receive(buf, off, len, waitMillis);
+ //}
+ //catch (SocketTimeoutException e)
+ //{
+ // return -1;
+ //}
+ //catch (InterruptedIOException e)
+ //{
+ // e.bytesTransferred = 0;
+ // throw e;
+ //}
+
+ try
+ {
+ return mTransport.Receive(buf, off, len, waitMillis);
+ }
+ catch (SocketException e)
+ {
+ if (TlsUtilities.IsTimeout(e))
+ return -1;
+
+ throw e;
+ }
+ }
+
private int ReceiveRecord(byte[] buf, int off, int len, int waitMillis)
{
if (mRecordQueue.Available > 0)
@@ -476,7 +515,7 @@ namespace Org.BouncyCastle.Crypto.Tls
}
{
- int received = mTransport.Receive(buf, off, len, waitMillis);
+ int received = ReceiveDatagram(buf, off, len, waitMillis);
if (received >= RECORD_HEADER_LENGTH)
{
int fragmentLength = TlsUtilities.ReadUint16(buf, off + 11);
@@ -524,7 +563,7 @@ namespace Org.BouncyCastle.Crypto.Tls
TlsUtilities.WriteUint16(ciphertext.Length, record, 11);
Array.Copy(ciphertext, 0, record, RECORD_HEADER_LENGTH, ciphertext.Length);
- mTransport.Send(record, 0, record.Length);
+ SendDatagram(mTransport, record, 0, record.Length);
}
private static long GetMacSequenceNumber(int epoch, long sequence_number)
diff --git a/crypto/src/crypto/tls/DtlsReliableHandshake.cs b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
index 8fcc1d7c2..92c222e70 100644
--- a/crypto/src/crypto/tls/DtlsReliableHandshake.cs
+++ b/crypto/src/crypto/tls/DtlsReliableHandshake.cs
@@ -76,6 +76,8 @@ namespace Org.BouncyCastle.Crypto.Tls
internal Message ReceiveMessage()
{
+ // TODO Add support for "overall" handshake timeout
+
if (mSending)
{
mSending = false;
@@ -89,41 +91,37 @@ namespace Org.BouncyCastle.Crypto.Tls
for (;;)
{
- try
+ if (mRecordLayer.IsClosed)
+ throw new TlsFatalAlert(AlertDescription.user_canceled);
+
+ Message pending = GetPendingMessage();
+ if (pending != null)
+ return pending;
+
+ int receiveLimit = mRecordLayer.GetReceiveLimit();
+ if (buf == null || buf.Length < receiveLimit)
{
- for (;;)
- {
- if (mRecordLayer.IsClosed)
- throw new TlsFatalAlert(AlertDescription.user_canceled);
-
- Message pending = GetPendingMessage();
- if (pending != null)
- return pending;
-
- int receiveLimit = mRecordLayer.GetReceiveLimit();
- if (buf == null || buf.Length < receiveLimit)
- {
- buf = new byte[receiveLimit];
- }
-
- int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis);
- if (received < 0)
- break;
-
- bool resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
- if (resentOutbound)
- {
- readTimeoutMillis = BackOff(readTimeoutMillis);
- }
- }
+ buf = new byte[receiveLimit];
+ }
+
+ int received = mRecordLayer.Receive(buf, 0, receiveLimit, readTimeoutMillis);
+
+ bool resentOutbound;
+ if (received < 0)
+ {
+ ResendOutboundFlight();
+ resentOutbound = true;
}
- catch (IOException)
+ else
{
- // NOTE: Assume this is a timeout for the moment
+ resentOutbound = ProcessRecord(MaxReceiveAhead, mRecordLayer.ReadEpoch, buf, 0, received);
}
- ResendOutboundFlight();
- readTimeoutMillis = BackOff(readTimeoutMillis);
+ // TODO Review conditions for resend/backoff
+ if (resentOutbound)
+ {
+ readTimeoutMillis = BackOff(readTimeoutMillis);
+ }
}
}
diff --git a/crypto/src/crypto/tls/DtlsTransport.cs b/crypto/src/crypto/tls/DtlsTransport.cs
index 5c607336b..bc09707c1 100644
--- a/crypto/src/crypto/tls/DtlsTransport.cs
+++ b/crypto/src/crypto/tls/DtlsTransport.cs
@@ -1,5 +1,6 @@
using System;
using System.IO;
+using System.Net.Sockets;
namespace Org.BouncyCastle.Crypto.Tls
{
@@ -25,6 +26,15 @@ namespace Org.BouncyCastle.Crypto.Tls
public virtual int Receive(byte[] buf, int off, int len, int waitMillis)
{
+ if (null == buf)
+ throw new ArgumentNullException("buf");
+ if (off < 0 || off >= buf.Length)
+ throw new ArgumentException("invalid offset: " + off, "off");
+ if (len < 0 || len > buf.Length - off)
+ throw new ArgumentException("invalid length: " + len, "len");
+ if (waitMillis < 0)
+ throw new ArgumentException("cannot be negative", "waitMillis");
+
try
{
return mRecordLayer.Receive(buf, off, len, waitMillis);
@@ -34,11 +44,23 @@ namespace Org.BouncyCastle.Crypto.Tls
mRecordLayer.Fail(fatalAlert.AlertDescription);
throw fatalAlert;
}
+ //catch (InterruptedIOException e)
+ //{
+ // throw e;
+ //}
catch (IOException e)
{
mRecordLayer.Fail(AlertDescription.internal_error);
throw e;
}
+ catch (SocketException e)
+ {
+ if (TlsUtilities.IsTimeout(e))
+ throw e;
+
+ mRecordLayer.Fail(AlertDescription.internal_error);
+ throw new TlsFatalAlert(AlertDescription.internal_error, e);
+ }
catch (Exception e)
{
mRecordLayer.Fail(AlertDescription.internal_error);
@@ -48,6 +70,13 @@ namespace Org.BouncyCastle.Crypto.Tls
public virtual void Send(byte[] buf, int off, int len)
{
+ if (null == buf)
+ throw new ArgumentNullException("buf");
+ if (off < 0 || off >= buf.Length)
+ throw new ArgumentException("invalid offset: " + off, "off");
+ if (len < 0 || len > buf.Length - off)
+ throw new ArgumentException("invalid length: " + len, "len");
+
try
{
mRecordLayer.Send(buf, off, len);
@@ -57,11 +86,23 @@ namespace Org.BouncyCastle.Crypto.Tls
mRecordLayer.Fail(fatalAlert.AlertDescription);
throw fatalAlert;
}
+ //catch (InterruptedIOException e)
+ //{
+ // throw e;
+ //}
catch (IOException e)
{
mRecordLayer.Fail(AlertDescription.internal_error);
throw e;
}
+ catch (SocketException e)
+ {
+ if (TlsUtilities.IsTimeout(e))
+ throw e;
+
+ mRecordLayer.Fail(AlertDescription.internal_error);
+ throw new TlsFatalAlert(AlertDescription.internal_error, e);
+ }
catch (Exception e)
{
mRecordLayer.Fail(AlertDescription.internal_error);
diff --git a/crypto/src/crypto/tls/TlsUtilities.cs b/crypto/src/crypto/tls/TlsUtilities.cs
index 6ee71021f..5aad6b0a1 100644
--- a/crypto/src/crypto/tls/TlsUtilities.cs
+++ b/crypto/src/crypto/tls/TlsUtilities.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections;
+using System.Net.Sockets;
using System.IO;
using System.Text;
@@ -2345,5 +2346,12 @@ namespace Org.BouncyCastle.Crypto.Tls
}
return v;
}
+
+ public static bool IsTimeout(SocketException e)
+ {
+ // TODO Net 2.0+
+ //return SocketError.TimedOut == e.SocketErrorCode;
+ return 10060 == e.ErrorCode;
+ }
}
}
|