summary refs log tree commit diff
path: root/crypto/src/asn1/DerBitString.cs
diff options
context:
space:
mode:
Diffstat (limited to 'crypto/src/asn1/DerBitString.cs')
-rw-r--r--crypto/src/asn1/DerBitString.cs242
1 files changed, 147 insertions, 95 deletions
diff --git a/crypto/src/asn1/DerBitString.cs b/crypto/src/asn1/DerBitString.cs
index 4dabb398f..caf5d6f9f 100644
--- a/crypto/src/asn1/DerBitString.cs
+++ b/crypto/src/asn1/DerBitString.cs
@@ -13,16 +13,12 @@ namespace Org.BouncyCastle.Asn1
 		private static readonly char[] table
 			= { '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F' };
 
-		protected readonly byte[]   mData;
-		protected readonly int      mPadBits;
-
         /**
 		 * return a Bit string from the passed in object
 		 *
 		 * @exception ArgumentException if the object cannot be converted.
 		 */
-		public static DerBitString GetInstance(
-			object obj)
+		public static DerBitString GetInstance(object obj)
 		{
 			if (obj == null || obj is DerBitString)
 			{
@@ -52,9 +48,7 @@ namespace Org.BouncyCastle.Asn1
 		 * @exception ArgumentException if the tagged object cannot
 		 *               be converted.
 		 */
-		public static DerBitString GetInstance(
-			Asn1TaggedObject	obj,
-			bool				isExplicit)
+		public static DerBitString GetInstance(Asn1TaggedObject obj, bool isExplicit)
 		{
 			Asn1Object o = obj.GetObject();
 
@@ -63,16 +57,30 @@ namespace Org.BouncyCastle.Asn1
 				return GetInstance(o);
 			}
 
-			return FromAsn1Octets(((Asn1OctetString)o).GetOctets());
+            // Not copied because assumed to be a tagged implicit primitive from the parser
+			return CreatePrimitive(((Asn1OctetString)o).GetOctets());
 		}
 
+        internal readonly byte[] contents;
+
+        public DerBitString(byte data, int padBits)
+        {
+            if (padBits > 7 || padBits < 0)
+                throw new ArgumentException("pad bits cannot be greater than 7 or less than 0", "padBits");
+
+            this.contents = new byte[] { (byte)padBits, data };
+        }
+
+        public DerBitString(byte[] data)
+            : this(data, 0)
+        {
+        }
+
         /**
 		 * @param data the octets making up the bit string.
 		 * @param padBits the number of extra bits at the end of the string.
 		 */
-		public DerBitString(
-			byte[]	data,
-			int		padBits)
+        public DerBitString(byte[] data, int padBits)
 		{
             if (data == null)
                 throw new ArgumentNullException("data");
@@ -81,42 +89,30 @@ namespace Org.BouncyCastle.Asn1
             if (data.Length == 0 && padBits != 0)
                 throw new ArgumentException("if 'data' is empty, 'padBits' must be 0");
 
-            this.mData = Arrays.Clone(data);
-			this.mPadBits = padBits;
-		}
-
-		public DerBitString(
-			byte[] data)
-            : this(data, 0)
-		{
-		}
+            this.contents = Arrays.Prepend(data, (byte)padBits);
+        }
 
-        public DerBitString(
-            int namedBits)
+        public DerBitString(int namedBits)
         {
             if (namedBits == 0)
             {
-                this.mData = new byte[0];
-                this.mPadBits = 0;
+                this.contents = new byte[]{ 0 };
                 return;
             }
 
             int bits = BigInteger.BitLen(namedBits);
             int bytes = (bits + 7) / 8;
-
             Debug.Assert(0 < bytes && bytes <= 4);
 
-            byte[] data = new byte[bytes];
-            --bytes;
+            byte[] data = new byte[1 + bytes];
 
-            for (int i = 0; i < bytes; i++)
+            for (int i = 1; i < bytes; i++)
             {
                 data[i] = (byte)namedBits;
                 namedBits >>= 8;
             }
 
             Debug.Assert((namedBits & 0xFF) != 0);
-
             data[bytes] = (byte)namedBits;
 
             int padBits = 0;
@@ -126,17 +122,38 @@ namespace Org.BouncyCastle.Asn1
             }
 
             Debug.Assert(padBits < 8);
+            data[0] = (byte)padBits;
 
-            this.mData = data;
-            this.mPadBits = padBits;
+            this.contents = data;
         }
 
-        public DerBitString(
-			Asn1Encodable obj)
+        public DerBitString(Asn1Encodable obj)
             : this(obj.GetDerEncoded())
 		{
 		}
 
+        internal DerBitString(byte[] contents, bool check)
+        {
+            if (check)
+            {
+                if (null == contents)
+                    throw new ArgumentNullException("contents");
+                if (contents.Length < 1)
+                    throw new ArgumentException("cannot be empty", "contents");
+
+                int padBits = contents[0];
+                if (padBits > 0)
+                {
+                    if (contents.Length < 2)
+                        throw new ArgumentException("zero length data with non-zero pad bits", "contents");
+                    if (padBits > 7)
+                        throw new ArgumentException("pad bits cannot be greater than 7 or less than 0", "contents");
+                }
+            }
+
+            this.contents = contents;
+        }
+
         /**
          * Return the octets contained in this BIT STRING, checking that this BIT STRING really
          * does represent an octet aligned string. Only use this method when the standard you are
@@ -146,28 +163,27 @@ namespace Org.BouncyCastle.Asn1
          */
         public virtual byte[] GetOctets()
         {
-            if (mPadBits != 0)
+            if (contents[0] != 0)
                 throw new InvalidOperationException("attempt to get non-octet aligned data from BIT STRING");
 
-            return Arrays.Clone(mData);
+            return Arrays.CopyOfRange(contents, 1, contents.Length);
         }
 
         public virtual byte[] GetBytes()
 		{
-            byte[] data = Arrays.Clone(mData);
+            if (contents.Length == 1)
+                return Asn1OctetString.EmptyOctets;
 
+            int padBits = contents[0];
+            byte[] rv = Arrays.CopyOfRange(contents, 1, contents.Length);
             // DER requires pad bits be zero
-            if (mPadBits > 0)
-            {
-                data[data.Length - 1] &= (byte)(0xFF << mPadBits);
-            }
-
-            return data;
-		}
+            rv[rv.Length - 1] &= (byte)(0xFF << padBits);
+            return rv;
+        }
 
         public virtual int PadBits
 		{
-			get { return mPadBits; }
+			get { return contents[0]; }
 		}
 
 		/**
@@ -177,68 +193,90 @@ namespace Org.BouncyCastle.Asn1
 		{
 			get
 			{
-                int value = 0, length = System.Math.Min(4, mData.Length);
-                for (int i = 0; i < length; ++i)
+                int value = 0, end = System.Math.Min(5, contents.Length - 1);
+                for (int i = 1; i < end; ++i)
                 {
-                    value |= (int)mData[i] << (8 * i);
+                    value |= (int)contents[i] << (8 * (i - 1));
                 }
-                if (mPadBits > 0 && length == mData.Length)
+                if (1 <= end && end < 5)
                 {
-                    int mask = (1 << mPadBits) - 1;
-                    value &= ~(mask << (8 * (length - 1)));
+                    int padBits = contents[0];
+                    byte der = (byte)(contents[end] & (0xFF << padBits));
+                    value |= (int)der << (8 * (end - 1));
                 }
                 return value;
-			}
+            }
 		}
 
         internal override int EncodedLength(bool withID)
         {
-            return Asn1OutputStream.GetLengthOfEncodingDL(withID, 1 + mData.Length);
+            return Asn1OutputStream.GetLengthOfEncodingDL(withID, contents.Length);
         }
 
         internal override void Encode(Asn1OutputStream asn1Out, bool withID)
 		{
-            if (mPadBits > 0)
-            {
-                int last = mData[mData.Length - 1];
-                int mask = (1 << mPadBits) - 1;
-                int unusedBits = last & mask;
-
-                if (unusedBits != 0)
-                {
-                    byte[] contents = Arrays.Prepend(mData, (byte)mPadBits);
+            int padBits = contents[0];
+            int length = contents.Length;
+            int last = length - 1;
 
-                    /*
-                     * X.690-0207 11.2.1: Each unused bit in the final octet of the encoding of a bit string value shall be set to zero.
-                     */
-                    contents[contents.Length - 1] = (byte)(last ^ unusedBits);
+            byte lastOctet = contents[last];
+            byte lastOctetDer = (byte)(contents[last] & (0xFF << padBits));
 
-                    asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, contents);
-                    return;
-                }
+            if (lastOctet == lastOctetDer)
+            {
+                asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, contents);
+            }
+            else
+            {
+                asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, contents, 0, last, lastOctetDer);
             }
-
-            asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, (byte)mPadBits, mData, 0, mData.Length);
 		}
 
         protected override int Asn1GetHashCode()
 		{
-			return mPadBits.GetHashCode() ^ Arrays.GetHashCode(mData);
-		}
+            if (contents.Length < 2)
+                return 1;
+
+            int padBits = contents[0];
+            int last = contents.Length - 1;
+
+            byte lastOctetDer = (byte)(contents[last] & (0xFF << padBits));
 
-		protected override bool Asn1Equals(
-			Asn1Object asn1Object)
+            int hc = Arrays.GetHashCode(contents, 0, last);
+            hc *= 257;
+            hc ^= lastOctetDer;
+            return hc;
+        }
+
+        protected override bool Asn1Equals(Asn1Object asn1Object)
 		{
-			DerBitString other = asn1Object as DerBitString;
+            DerBitString that = asn1Object as DerBitString;
+            if (null == that)
+                return false;
 
-			if (other == null)
-				return false;
+            byte[] thisContents = this.contents, thatContents = that.contents;
 
-			return this.mPadBits == other.mPadBits
-				&& Arrays.AreEqual(this.mData, other.mData);
-		}
+            int length = thisContents.Length;
+            if (thatContents.Length != length)
+                return false;
+            if (length == 1)
+                return true;
+
+            int last = length - 1;
+            for (int i = 0; i < last; ++i)
+            {
+                if (thisContents[i] != thatContents[i])
+                    return false;
+            }
+
+            int padBits = thisContents[0];
+            byte thisLastOctetDer = (byte)(thisContents[last] & (0xFF << padBits));
+            byte thatLastOctetDer = (byte)(thatContents[last] & (0xFF << padBits));
+
+            return thisLastOctetDer == thatLastOctetDer;
+        }
 
-		public override string GetString()
+        public override string GetString()
 		{
 			StringBuilder buffer = new StringBuilder("#");
 
@@ -254,27 +292,41 @@ namespace Org.BouncyCastle.Asn1
 			return buffer.ToString();
 		}
 
-		internal static DerBitString FromAsn1Octets(byte[] octets)
-		{
-	        if (octets.Length < 1)
-	            throw new ArgumentException("truncated BIT STRING detected", "octets");
+        internal static int EncodedLength(bool withID, int contentsLength)
+        {
+            return Asn1OutputStream.GetLengthOfEncodingDL(withID, contentsLength);
+        }
 
-            int padBits = octets[0];
-            byte[] data = Arrays.CopyOfRange(octets, 1, octets.Length);
+        internal static void Encode(Asn1OutputStream asn1Out, bool withID, byte[] buf, int off, int len)
+        {
+            asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, buf, off, len);
+        }
+
+        internal static void Encode(Asn1OutputStream asn1Out, bool withID, byte pad, byte[] buf, int off, int len)
+        {
+            asn1Out.WriteEncodingDL(withID, Asn1Tags.BitString, pad, buf, off, len);
+        }
+
+		internal static DerBitString CreatePrimitive(byte[] contents)
+		{
+            int length = contents.Length;
+            if (length < 1)
+                throw new ArgumentException("truncated BIT STRING detected", "contents");
 
-            if (padBits > 0 && padBits < 8 && data.Length > 0)
+            int padBits = contents[0];
+            if (padBits > 0)
             {
-                int last = data[data.Length - 1];
-                int mask = (1 << padBits) - 1;
+                if (padBits > 7 || length < 2)
+                    throw new ArgumentException("invalid pad bits detected", "contents");
 
-                if ((last & mask) != 0)
+                byte finalOctet = contents[length - 1];
+                if (finalOctet != (byte)(finalOctet & (0xFF << padBits)))
                 {
-                    return new BerBitString(data, padBits);
+                    return new BerBitString(contents, false);
                 }
             }
 
-            return new DerBitString(data, padBits);
+            return new DerBitString(contents, false);
 		}
 	}
 }
-