changeset 8816:bc316611d450

8068427: Hashtable deserialization reconstitutes table with wrong capacity Reviewed-by: mduigou, martin, chegar, dfuchs
author plevart
date Thu, 23 Nov 2017 20:15:51 +0000
parents 23e15121a4a6
children 512cfc54ab90
files src/share/classes/java/util/Hashtable.java test/java/util/Hashtable/DeserializedLength.java
diffstat 2 files changed, 133 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/src/share/classes/java/util/Hashtable.java	Thu Nov 23 19:50:37 2017 +0000
+++ b/src/share/classes/java/util/Hashtable.java	Thu Nov 23 20:15:51 2017 +0000
@@ -935,10 +935,10 @@
         Entry<K, V> entryStack = null;
 
         synchronized (this) {
-            // Write out the length, threshold, loadfactor
+            // Write out the threshold and loadFactor
             s.defaultWriteObject();
 
-            // Write out length, count of elements
+            // Write out the length and count of elements
             s.writeInt(table.length);
             s.writeInt(count);
 
@@ -968,22 +968,33 @@
     private void readObject(java.io.ObjectInputStream s)
          throws IOException, ClassNotFoundException
     {
-        // Read in the length, threshold, and loadfactor
+        // Read in the threshold and loadFactor
         s.defaultReadObject();
 
+        // Validate loadFactor (ignore threshold - it will be re-computed)
+        if (loadFactor <= 0 || Float.isNaN(loadFactor))
+            throw new StreamCorruptedException("Illegal Load: " + loadFactor);
+
         // Read the original length of the array and number of elements
         int origlength = s.readInt();
         int elements = s.readInt();
 
-        // Compute new size with a bit of room 5% to grow but
-        // no larger than the original size.  Make the length
+        // Validate # of elements
+        if (elements < 0)
+            throw new StreamCorruptedException("Illegal # of Elements: " + elements);
+
+        // Clamp original length to be more than elements / loadFactor
+        // (this is the invariant enforced with auto-growth)
+        origlength = Math.max(origlength, (int)(elements / loadFactor) + 1);
+
+        // Compute new length with a bit of room 5% + 3 to grow but
+        // no larger than the clamped original length.  Make the length
         // odd if it's large enough, this helps distribute the entries.
         // Guard against the length ending up zero, that's not valid.
-        int length = (int)(elements * loadFactor) + (elements / 20) + 3;
+        int length = (int)((elements + elements / 20) / loadFactor) + 3;
         if (length > elements && (length & 1) == 0)
             length--;
-        if (origlength > 0 && length > origlength)
-            length = origlength;
+        length = Math.min(length, origlength);
 
         Entry<K,V>[] newTable = new Entry[length];
         threshold = (int) Math.min(length * loadFactor, MAX_ARRAY_SIZE + 1);
@@ -994,7 +1005,7 @@
         for (; elements > 0; elements--) {
             K key = (K)s.readObject();
             V value = (V)s.readObject();
-            // synch could be eliminated for performance
+            // sync is eliminated for performance
             reconstitutionPut(newTable, key, value);
         }
         this.table = newTable;
@@ -1007,9 +1018,9 @@
      *
      * <p>This differs from the regular put method in several ways. No
      * checking for rehashing is necessary since the number of elements
-     * initially in the table is known. The modCount is not incremented
-     * because we are creating a new instance. Also, no return value
-     * is needed.
+     * initially in the table is known. The modCount is not incremented and
+     * there's no synchronization because we are creating a new instance.
+     * Also, no return value is needed.
      */
     private void reconstitutionPut(Entry<K,V>[] tab, K key, V value)
         throws StreamCorruptedException
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/java/util/Hashtable/DeserializedLength.java	Thu Nov 23 20:15:51 2017 +0000
@@ -0,0 +1,110 @@
+/*
+ * Copyright (c) 2014, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+import java.io.*;
+import java.lang.reflect.Field;
+import java.util.Hashtable;
+
+/**
+ * @test
+ * @bug 8068427
+ * @summary Hashtable deserialization reconstitutes table with wrong capacity
+ */
+public class DeserializedLength {
+
+    static boolean testDeserializedLength(int elements, float loadFactor) throws Exception {
+
+        // construct Hashtable with minimal initial capacity and given loadFactor
+        Hashtable<Integer, Integer> ht1 = new Hashtable<>(1, loadFactor);
+
+        // add given number of unique elements
+        for (int i = 0; i < elements; i++) {
+            ht1.put(i, i);
+        }
+
+        // serialize and deserialize into a deep clone
+        Hashtable<Integer, Integer> ht2 = serialClone(ht1);
+
+        // compare lengths of internal tables
+        Object[] table1 = (Object[]) hashtableTableField.get(ht1);
+        Object[] table2 = (Object[]) hashtableTableField.get(ht2);
+        assert table1 != null;
+        assert table2 != null;
+
+        int minLength = (int) (ht1.size() / loadFactor) + 1;
+        int maxLength = minLength * 2;
+
+        boolean ok = (table2.length >= minLength && table2.length <= maxLength);
+
+        System.out.printf(
+            "%7d %5.2f %7d %7d %7d...%7d %s\n",
+            ht1.size(), loadFactor,
+            table1.length, table2.length,
+            minLength, maxLength,
+            (ok ? "OK" : "NOT-OK")
+        );
+
+        return ok;
+    }
+
+    static <T> T serialClone(T o) throws IOException, ClassNotFoundException {
+        ByteArrayOutputStream bos = new ByteArrayOutputStream();
+        try (ObjectOutputStream oos = new ObjectOutputStream(bos)) {
+            oos.writeObject(o);
+        }
+        @SuppressWarnings("unchecked")
+        T clone = (T) new ObjectInputStream(
+            new ByteArrayInputStream(bos.toByteArray())).readObject();
+        return clone;
+    }
+
+    private static final Field hashtableTableField;
+
+    static {
+        try {
+            hashtableTableField = Hashtable.class.getDeclaredField("table");
+            hashtableTableField.setAccessible(true);
+        } catch (NoSuchFieldException e) {
+            throw new Error(e);
+        }
+    }
+
+    public static void main(String[] args) throws Exception {
+        boolean ok = true;
+
+        System.out.printf("Results:\n" +
+                "                 ser.  deser.\n" +
+                "   size  load  lentgh  length       valid range ok?\n" +
+                "------- ----- ------- ------- ----------------- ------\n"
+        );
+
+        for (int elements : new int[]{10, 50, 500, 5000}) {
+            for (float loadFactor : new float[]{0.15f, 0.5f, 0.75f, 1.0f, 2.5f}) {
+                ok &= testDeserializedLength(elements, loadFactor);
+            }
+        }
+        if (!ok) {
+            throw new AssertionError("Test failed.");
+        }
+    }
+}