changeset 17271:ed0eab5829b2

8163958: Improved garbage collection Reviewed-by: smarks, chegar, skoivu, rhalade
author rriggs
date Thu, 16 Mar 2017 16:16:31 -0400
parents 092e0cea6d40
children 8645b4aed22f
files make/rmic/Rmic-java.rmi.gmk src/java.base/share/classes/java/util/Vector.java src/java.rmi/share/classes/sun/rmi/server/UnicastRef.java src/java.rmi/share/classes/sun/rmi/server/UnicastServerRef.java src/java.rmi/share/classes/sun/rmi/transport/ConnectionInputStream.java src/java.rmi/share/classes/sun/rmi/transport/DGCClient.java src/java.rmi/share/classes/sun/rmi/transport/DGCImpl_Skel.java src/java.rmi/share/classes/sun/rmi/transport/DGCImpl_Stub.java src/java.rmi/share/classes/sun/rmi/transport/StreamRemoteCall.java test/java/rmi/testlibrary/TestSocketFactory.java
diffstat 10 files changed, 936 insertions(+), 19 deletions(-) [+]
line wrap: on
line diff
--- a/make/rmic/Rmic-java.rmi.gmk	Tue Mar 14 19:15:42 2017 -0700
+++ b/make/rmic/Rmic-java.rmi.gmk	Thu Mar 16 16:16:31 2017 -0400
@@ -41,8 +41,7 @@
 GENCLASSES += $(RMI_12)
 
 $(eval $(call SetupRMICompilation,RMI_11, \
-    CLASSES := sun.rmi.registry.RegistryImpl \
-        sun.rmi.transport.DGCImpl, \
+    CLASSES := sun.rmi.registry.RegistryImpl, \
     CLASSES_DIR := $(CLASSES_DIR)/java.rmi, \
     STUB_CLASSES_DIR := $(STUB_CLASSES_DIR)/java.rmi, \
     RUN_V11 := true))
--- a/src/java.base/share/classes/java/util/Vector.java	Tue Mar 14 19:15:42 2017 -0700
+++ b/src/java.base/share/classes/java/util/Vector.java	Thu Mar 16 16:16:31 2017 -0400
@@ -154,7 +154,7 @@
 
     /**
      * Constructs an empty vector so that its internal data array
-     * has size {@code 10} and its standard capacity increment is
+     * has size {@code 10} and its standard capacBasity increment is
      * zero.
      */
     public Vector() {
--- a/src/java.rmi/share/classes/sun/rmi/server/UnicastRef.java	Tue Mar 14 19:15:42 2017 -0700
+++ b/src/java.rmi/share/classes/sun/rmi/server/UnicastRef.java	Thu Mar 16 16:16:31 2017 -0400
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2013, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2017, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -32,7 +32,6 @@
 import java.rmi.MarshalException;
 import java.rmi.Remote;
 import java.rmi.RemoteException;
-import java.rmi.ServerException;
 import java.rmi.UnmarshalException;
 import java.rmi.server.Operation;
 import java.rmi.server.RemoteCall;
@@ -187,14 +186,11 @@
 
                 return returnValue;
 
-            } catch (IOException e) {
+            } catch (IOException | ClassNotFoundException e) {
+                // disable saving any refs in the inputStream for GC
+                ((StreamRemoteCall)call).discardPendingRefs();
                 clientRefLog.log(Log.BRIEF,
-                                 "IOException unmarshalling return: ", e);
-                throw new UnmarshalException("error unmarshalling return", e);
-            } catch (ClassNotFoundException e) {
-                clientRefLog.log(Log.BRIEF,
-                    "ClassNotFoundException unmarshalling return: ", e);
-
+                                 e.getClass().getName() + " unmarshalling return: ", e);
                 throw new UnmarshalException("error unmarshalling return", e);
             } finally {
                 try {
--- a/src/java.rmi/share/classes/sun/rmi/server/UnicastServerRef.java	Tue Mar 14 19:15:42 2017 -0700
+++ b/src/java.rmi/share/classes/sun/rmi/server/UnicastServerRef.java	Thu Mar 16 16:16:31 2017 -0400
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2016, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2017, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -57,6 +57,7 @@
 import java.util.concurrent.atomic.AtomicInteger;
 import sun.rmi.runtime.Log;
 import sun.rmi.transport.LiveRef;
+import sun.rmi.transport.StreamRemoteCall;
 import sun.rmi.transport.Target;
 import sun.rmi.transport.tcp.TCPTransport;
 
@@ -328,10 +329,9 @@
             try {
                 unmarshalCustomCallData(in);
                 params = unmarshalParameters(obj, method, marshalStream);
-            } catch (java.io.IOException e) {
-                throw new UnmarshalException(
-                    "error unmarshalling arguments", e);
-            } catch (ClassNotFoundException e) {
+            } catch (java.io.IOException | ClassNotFoundException e) {
+                // disable saving any refs in the inputStream for GC
+                ((StreamRemoteCall) call).discardPendingRefs();
                 throw new UnmarshalException(
                     "error unmarshalling arguments", e);
             } finally {
--- a/src/java.rmi/share/classes/sun/rmi/transport/ConnectionInputStream.java	Tue Mar 14 19:15:42 2017 -0700
+++ b/src/java.rmi/share/classes/sun/rmi/transport/ConnectionInputStream.java	Thu Mar 16 16:16:31 2017 -0400
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2012, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2017, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -82,6 +82,14 @@
     }
 
     /**
+     * Discard the saved incoming refs so there is nothing to register
+     * when {@code registerRefs} is called.
+     */
+    void discardRefs() {
+        incomingRefTable.clear();
+    }
+
+    /**
      * Add references to DGC table (and possibly send dirty call).
      * RegisterRefs now calls DGCClient.referenced on all
      * refs with the same endpoint at once to achieve batching of
--- a/src/java.rmi/share/classes/sun/rmi/transport/DGCClient.java	Tue Mar 14 19:15:42 2017 -0700
+++ b/src/java.rmi/share/classes/sun/rmi/transport/DGCClient.java	Thu Mar 16 16:16:31 2017 -0400
@@ -24,9 +24,11 @@
  */
 package sun.rmi.transport;
 
+import java.io.InvalidClassException;
 import java.lang.ref.PhantomReference;
 import java.lang.ref.ReferenceQueue;
 import java.net.SocketPermission;
+import java.rmi.UnmarshalException;
 import java.security.AccessController;
 import java.security.PrivilegedAction;
 import java.util.HashMap;
@@ -41,6 +43,8 @@
 import java.rmi.dgc.Lease;
 import java.rmi.dgc.VMID;
 import java.rmi.server.ObjID;
+
+import sun.rmi.runtime.Log;
 import sun.rmi.runtime.NewThreadAction;
 import sun.rmi.server.UnicastRef;
 import sun.rmi.server.Util;
@@ -388,6 +392,12 @@
                 synchronized (this) {
                     dirtyFailures++;
 
+                    if (e instanceof UnmarshalException
+                            && e.getCause() instanceof InvalidClassException) {
+                        DGCImpl.dgcLog.log(Log.BRIEF, "InvalidClassException exception in DGC dirty call", e);
+                        return;             // protocol error, do not register these refs
+                    }
+
                     if (dirtyFailures == 1) {
                         /*
                          * If this was the first recent failed dirty call,
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/java.rmi/share/classes/sun/rmi/transport/DGCImpl_Skel.java	Thu Mar 16 16:16:31 2017 -0400
@@ -0,0 +1,112 @@
+/*
+ * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.  Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package sun.rmi.transport;
+
+/**
+ * Skeleton to dispatch DGC methods.
+ * Originally generated by RMIC but frozen to match the stubs.
+ */
+@SuppressWarnings({"deprecation", "serial"})
+public final class DGCImpl_Skel
+        implements java.rmi.server.Skeleton {
+    private static final java.rmi.server.Operation[] operations = {
+            new java.rmi.server.Operation("void clean(java.rmi.server.ObjID[], long, java.rmi.dgc.VMID, boolean)"),
+            new java.rmi.server.Operation("java.rmi.dgc.Lease dirty(java.rmi.server.ObjID[], long, java.rmi.dgc.Lease)")
+    };
+
+    private static final long interfaceHash = -669196253586618813L;
+
+    public java.rmi.server.Operation[] getOperations() {
+        return operations.clone();
+    }
+
+    public void dispatch(java.rmi.Remote obj, java.rmi.server.RemoteCall call, int opnum, long hash)
+            throws java.lang.Exception {
+        if (hash != interfaceHash)
+            throw new java.rmi.server.SkeletonMismatchException("interface hash mismatch");
+
+        sun.rmi.transport.DGCImpl server = (sun.rmi.transport.DGCImpl) obj;
+        switch (opnum) {
+            case 0: // clean(ObjID[], long, VMID, boolean)
+            {
+                java.rmi.server.ObjID[] $param_arrayOf_ObjID_1;
+                long $param_long_2;
+                java.rmi.dgc.VMID $param_VMID_3;
+                boolean $param_boolean_4;
+                try {
+                    java.io.ObjectInput in = call.getInputStream();
+                    $param_arrayOf_ObjID_1 = (java.rmi.server.ObjID[]) in.readObject();
+                    $param_long_2 = in.readLong();
+                    $param_VMID_3 = (java.rmi.dgc.VMID) in.readObject();
+                    $param_boolean_4 = in.readBoolean();
+                } catch (java.io.IOException e) {
+                    throw new java.rmi.UnmarshalException("error unmarshalling arguments", e);
+                } catch (java.lang.ClassNotFoundException e) {
+                    throw new java.rmi.UnmarshalException("error unmarshalling arguments", e);
+                } finally {
+                    call.releaseInputStream();
+                }
+                server.clean($param_arrayOf_ObjID_1, $param_long_2, $param_VMID_3, $param_boolean_4);
+                try {
+                    call.getResultStream(true);
+                } catch (java.io.IOException e) {
+                    throw new java.rmi.MarshalException("error marshalling return", e);
+                }
+                break;
+            }
+
+            case 1: // dirty(ObjID[], long, Lease)
+            {
+                java.rmi.server.ObjID[] $param_arrayOf_ObjID_1;
+                long $param_long_2;
+                java.rmi.dgc.Lease $param_Lease_3;
+                try {
+                    java.io.ObjectInput in = call.getInputStream();
+                    $param_arrayOf_ObjID_1 = (java.rmi.server.ObjID[]) in.readObject();
+                    $param_long_2 = in.readLong();
+                    $param_Lease_3 = (java.rmi.dgc.Lease) in.readObject();
+                } catch (java.io.IOException e) {
+                    throw new java.rmi.UnmarshalException("error unmarshalling arguments", e);
+                } catch (java.lang.ClassNotFoundException e) {
+                    throw new java.rmi.UnmarshalException("error unmarshalling arguments", e);
+                } finally {
+                    call.releaseInputStream();
+                }
+                java.rmi.dgc.Lease $result = server.dirty($param_arrayOf_ObjID_1, $param_long_2, $param_Lease_3);
+                try {
+                    java.io.ObjectOutput out = call.getResultStream(true);
+                    out.writeObject($result);
+                } catch (java.io.IOException e) {
+                    throw new java.rmi.MarshalException("error marshalling return", e);
+                }
+                break;
+            }
+
+            default:
+                throw new java.rmi.UnmarshalException("invalid method number");
+        }
+    }
+}
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/java.rmi/share/classes/sun/rmi/transport/DGCImpl_Stub.java	Thu Mar 16 16:16:31 2017 -0400
@@ -0,0 +1,183 @@
+/*
+ * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.  Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package sun.rmi.transport;
+
+import java.io.ObjectInputFilter;
+import java.io.ObjectInputStream;
+import java.rmi.dgc.Lease;
+import java.rmi.dgc.VMID;
+import java.rmi.server.UID;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+
+import sun.rmi.server.UnicastRef;
+import sun.rmi.transport.tcp.TCPConnection;
+
+/**
+ * Stubs to invoke DGC remote methods.
+ * Originally generated from RMIC but frozen to insert serialFilter.
+ */
+@SuppressWarnings({"deprecation", "serial"})
+public final class DGCImpl_Stub
+        extends java.rmi.server.RemoteStub
+        implements java.rmi.dgc.DGC {
+    private static final java.rmi.server.Operation[] operations = {
+            new java.rmi.server.Operation("void clean(java.rmi.server.ObjID[], long, java.rmi.dgc.VMID, boolean)"),
+            new java.rmi.server.Operation("java.rmi.dgc.Lease dirty(java.rmi.server.ObjID[], long, java.rmi.dgc.Lease)")
+    };
+
+    private static final long interfaceHash = -669196253586618813L;
+
+    /** Registry max depth of remote invocations. **/
+    private static int DGCCLIENT_MAX_DEPTH = 6;
+
+    /** Registry maximum array size in remote invocations. **/
+    private static int DGCCLIENT_MAX_ARRAY_SIZE = 10000;
+
+    // constructors
+    public DGCImpl_Stub() {
+        super();
+    }
+
+    public DGCImpl_Stub(java.rmi.server.RemoteRef ref) {
+        super(ref);
+    }
+
+    // methods from remote interfaces
+
+    // implementation of clean(ObjID[], long, VMID, boolean)
+    public void clean(java.rmi.server.ObjID[] $param_arrayOf_ObjID_1, long $param_long_2, java.rmi.dgc.VMID $param_VMID_3, boolean $param_boolean_4)
+            throws java.rmi.RemoteException {
+        try {
+            java.rmi.server.RemoteCall call = ref.newCall((java.rmi.server.RemoteObject) this, operations, 0, interfaceHash);
+            try {
+                java.io.ObjectOutput out = call.getOutputStream();
+                out.writeObject($param_arrayOf_ObjID_1);
+                out.writeLong($param_long_2);
+                out.writeObject($param_VMID_3);
+                out.writeBoolean($param_boolean_4);
+            } catch (java.io.IOException e) {
+                throw new java.rmi.MarshalException("error marshalling arguments", e);
+            }
+            ref.invoke(call);
+            ref.done(call);
+        } catch (java.lang.RuntimeException e) {
+            throw e;
+        } catch (java.rmi.RemoteException e) {
+            throw e;
+        } catch (java.lang.Exception e) {
+            throw new java.rmi.UnexpectedException("undeclared checked exception", e);
+        }
+    }
+
+    // implementation of dirty(ObjID[], long, Lease)
+    public java.rmi.dgc.Lease dirty(java.rmi.server.ObjID[] $param_arrayOf_ObjID_1, long $param_long_2, java.rmi.dgc.Lease $param_Lease_3)
+            throws java.rmi.RemoteException {
+        try {
+            java.rmi.server.RemoteCall call = ref.newCall((java.rmi.server.RemoteObject) this, operations, 1, interfaceHash);
+            try {
+                java.io.ObjectOutput out = call.getOutputStream();
+                out.writeObject($param_arrayOf_ObjID_1);
+                out.writeLong($param_long_2);
+                out.writeObject($param_Lease_3);
+            } catch (java.io.IOException e) {
+                throw new java.rmi.MarshalException("error marshalling arguments", e);
+            }
+            ref.invoke(call);
+            java.rmi.dgc.Lease $result;
+            Connection connection = ((StreamRemoteCall) call).getConnection();
+            try {
+                java.io.ObjectInput in = call.getInputStream();
+
+                if (in instanceof ObjectInputStream) {
+                    /**
+                     * Set a filter on the stream for the return value.
+                     */
+                    ObjectInputStream ois = (ObjectInputStream) in;
+                    AccessController.doPrivileged((PrivilegedAction<Void>)() -> {
+                        ois.setObjectInputFilter(DGCImpl_Stub::leaseFilter);
+                        return null;
+                    });
+                }
+                $result = (java.rmi.dgc.Lease) in.readObject();
+            } catch (java.io.IOException | java.lang.ClassNotFoundException e) {
+                if (connection instanceof TCPConnection) {
+                    // Modified to prevent re-use of the connection after an exception
+                    ((TCPConnection) connection).getChannel().free(connection, false);
+                }
+                throw new java.rmi.UnmarshalException("error unmarshalling return", e);
+            } finally {
+                ref.done(call);
+            }
+            return $result;
+        } catch (java.lang.RuntimeException e) {
+            throw e;
+        } catch (java.rmi.RemoteException e) {
+            throw e;
+        } catch (java.lang.Exception e) {
+            throw new java.rmi.UnexpectedException("undeclared checked exception", e);
+        }
+    }
+
+    /**
+     * ObjectInputFilter to filter DGCClient return value (a Lease).
+     * The list of acceptable classes is very short and explicit.
+     * The depth and array sizes are limited.
+     *
+     * @param filterInfo access to class, arrayLength, etc.
+     * @return  {@link ObjectInputFilter.Status#ALLOWED} if allowed,
+     *          {@link ObjectInputFilter.Status#REJECTED} if rejected,
+     *          otherwise {@link ObjectInputFilter.Status#UNDECIDED}
+     */
+    private static ObjectInputFilter.Status leaseFilter(ObjectInputFilter.FilterInfo filterInfo) {
+
+        if (filterInfo.depth() > DGCCLIENT_MAX_DEPTH) {
+            return ObjectInputFilter.Status.REJECTED;
+        }
+        Class<?> clazz = filterInfo.serialClass();
+        if (clazz != null) {
+            while (clazz.isArray()) {
+                if (filterInfo.arrayLength() >= 0 && filterInfo.arrayLength() > DGCCLIENT_MAX_ARRAY_SIZE) {
+                    return ObjectInputFilter.Status.REJECTED;
+                }
+                // Arrays are allowed depending on the component type
+                clazz = clazz.getComponentType();
+            }
+            if (clazz.isPrimitive()) {
+                // Arrays of primitives are allowed
+                return ObjectInputFilter.Status.ALLOWED;
+            }
+            return (clazz == UID.class ||
+                    clazz == VMID.class ||
+                    clazz == Lease.class)
+                    ? ObjectInputFilter.Status.ALLOWED
+                    : ObjectInputFilter.Status.REJECTED;
+        }
+        // Not a class, not size limited
+        return ObjectInputFilter.Status.UNDECIDED;
+    }
+
+}
--- a/src/java.rmi/share/classes/sun/rmi/transport/StreamRemoteCall.java	Tue Mar 14 19:15:42 2017 -0700
+++ b/src/java.rmi/share/classes/sun/rmi/transport/StreamRemoteCall.java	Thu Mar 16 16:16:31 2017 -0400
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 1996, 2013, Oracle and/or its affiliates. All rights reserved.
+ * Copyright (c) 1996, 2017, Oracle and/or its affiliates. All rights reserved.
  * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
  *
  * This code is free software; you can redistribute it and/or modify it
@@ -168,6 +168,13 @@
     }
 
     /**
+     * Discard any post-processing of refs the InputStream.
+     */
+    public void discardPendingRefs() {
+        in.discardRefs();
+    }
+
+    /**
      * Returns an output stream (may put out header information
      * relating to the success of the call).
      * @param success If true, indicates normal return, else indicates
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/test/java/rmi/testlibrary/TestSocketFactory.java	Thu Mar 16 16:16:31 2017 -0400
@@ -0,0 +1,602 @@
+/*
+ * Copyright (c) 2017, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+import java.io.ByteArrayOutputStream;
+import java.io.FilterInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.io.Serializable;
+import java.net.InetAddress;
+import java.net.ServerSocket;
+import java.net.Socket;
+import java.net.SocketAddress;
+import java.net.SocketException;
+import java.net.SocketOption;
+import java.nio.channels.ServerSocketChannel;
+import java.nio.channels.SocketChannel;
+import java.rmi.server.RMIClientSocketFactory;
+import java.rmi.server.RMIServerSocketFactory;
+import java.rmi.server.RMISocketFactory;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+
+import org.testng.Assert;
+import org.testng.TestNG;
+import org.testng.annotations.Test;
+import org.testng.annotations.DataProvider;
+
+
+/**
+ * A RMISocketFactory utility factory to log RMI stream contents and to
+ * match and replace output stream contents to simulate failures.
+ */
+public class TestSocketFactory extends RMISocketFactory
+        implements RMIClientSocketFactory, RMIServerSocketFactory, Serializable {
+
+    private static final long serialVersionUID = 1L;
+
+    private volatile transient byte[] matchBytes;
+
+    private volatile transient byte[] replaceBytes;
+
+    private transient final List<InterposeSocket> sockets = new ArrayList<>();
+
+    private transient final List<InterposeServerSocket> serverSockets = new ArrayList<>();
+
+    public static final boolean DEBUG = false;
+
+    /**
+     * Debugging output can be synchronized with logging of RMI actions.
+     *
+     * @param format a printf format
+     * @param args   any args
+     */
+    private static void DEBUG(String format, Object... args) {
+        if (DEBUG) {
+            System.err.printf(format, args);
+        }
+    }
+
+    /**
+     * Create a socket factory that creates InputStreams that log
+     * and OutputStreams that log .
+     */
+    public TestSocketFactory() {
+        this.matchBytes = new byte[0];
+        this.replaceBytes = this.matchBytes;
+        System.out.printf("Creating TestSocketFactory()%n");
+    }
+
+    public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) {
+        this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes");
+        this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes");
+        sockets.forEach( s -> s.setMatchReplaceBytes(matchBytes, replaceBytes));
+        serverSockets.forEach( s -> s.setMatchReplaceBytes(matchBytes, replaceBytes));
+
+    }
+
+    @Override
+    public Socket createSocket(String host, int port) throws IOException {
+        Socket socket = RMISocketFactory.getDefaultSocketFactory()
+                .createSocket(host, port);
+        InterposeSocket s = new InterposeSocket(socket, matchBytes, replaceBytes);
+        sockets.add(s);
+        return s;
+    }
+
+    /**
+     * Return the current list of sockets.
+     * @return Return a snapshot of the current list of sockets
+     */
+    public List<InterposeSocket> getSockets() {
+        List<InterposeSocket> snap = new ArrayList<>(sockets);
+        return snap;
+    }
+
+    @Override
+    public ServerSocket createServerSocket(int port) throws IOException {
+
+        ServerSocket serverSocket = RMISocketFactory.getDefaultSocketFactory()
+                .createServerSocket(port);
+        InterposeServerSocket ss = new InterposeServerSocket(serverSocket, matchBytes, replaceBytes);
+        serverSockets.add(ss);
+        return ss;
+    }
+
+    /**
+     * Return the current list of server sockets.
+     * @return Return a snapshot of the current list of server sockets
+     */
+    public List<InterposeServerSocket> getServerSockets() {
+        List<InterposeServerSocket> snap = new ArrayList<>(serverSockets);
+        return snap;
+    }
+
+    /**
+     * An InterposeSocket wraps a socket that produces InputStreams
+     * and OutputStreams that log the traffic.
+     * The OutputStreams it produces match an array of bytes and replace them.
+     * Useful for injecting protocol and content errors.
+     */
+    public static class InterposeSocket extends Socket {
+        private final Socket socket;
+        private InputStream in;
+        private MatchReplaceOutputStream out;
+        private volatile byte[] matchBytes;
+        private volatile byte[] replaceBytes;
+        private final ByteArrayOutputStream inLogStream;
+        private final ByteArrayOutputStream outLogStream;
+        private final String name;
+        private static volatile int num = 0;    // index for created InterposeSockets
+
+        public InterposeSocket(Socket socket, byte[] matchBytes, byte[] replaceBytes) {
+            this.socket = socket;
+            this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes");
+            this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes");
+            this.inLogStream = new ByteArrayOutputStream();
+            this.outLogStream = new ByteArrayOutputStream();
+            this.name = "IS" + ++num + "::"
+                    + Thread.currentThread().getName() + ": "
+                    + socket.getLocalPort() + " <  " + socket.getPort();
+        }
+
+        public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) {
+            this.matchBytes = matchBytes;
+            this.replaceBytes = replaceBytes;
+            out.setMatchReplaceBytes(matchBytes, replaceBytes);
+        }
+
+        @Override
+        public void connect(SocketAddress endpoint) throws IOException {
+            socket.connect(endpoint);
+        }
+
+        @Override
+        public void connect(SocketAddress endpoint, int timeout) throws IOException {
+            socket.connect(endpoint, timeout);
+        }
+
+        @Override
+        public void bind(SocketAddress bindpoint) throws IOException {
+            socket.bind(bindpoint);
+        }
+
+        @Override
+        public InetAddress getInetAddress() {
+            return socket.getInetAddress();
+        }
+
+        @Override
+        public InetAddress getLocalAddress() {
+            return socket.getLocalAddress();
+        }
+
+        @Override
+        public int getPort() {
+            return socket.getPort();
+        }
+
+        @Override
+        public int getLocalPort() {
+            return socket.getLocalPort();
+        }
+
+        @Override
+        public SocketAddress getRemoteSocketAddress() {
+            return socket.getRemoteSocketAddress();
+        }
+
+        @Override
+        public SocketAddress getLocalSocketAddress() {
+            return socket.getLocalSocketAddress();
+        }
+
+        @Override
+        public SocketChannel getChannel() {
+            return socket.getChannel();
+        }
+
+        @Override
+        public synchronized void close() throws IOException {
+            socket.close();
+        }
+
+        @Override
+        public String toString() {
+            return "InterposeSocket " + name + ": " + socket.toString();
+        }
+
+        @Override
+        public boolean isConnected() {
+            return socket.isConnected();
+        }
+
+        @Override
+        public boolean isBound() {
+            return socket.isBound();
+        }
+
+        @Override
+        public boolean isClosed() {
+            return socket.isClosed();
+        }
+
+        @Override
+        public <T> Socket setOption(SocketOption<T> name, T value) throws IOException {
+            return socket.setOption(name, value);
+        }
+
+        @Override
+        public <T> T getOption(SocketOption<T> name) throws IOException {
+            return socket.getOption(name);
+        }
+
+        @Override
+        public Set<SocketOption<?>> supportedOptions() {
+            return socket.supportedOptions();
+        }
+
+        @Override
+        public synchronized InputStream getInputStream() throws IOException {
+            if (in == null) {
+                in = socket.getInputStream();
+                String name = Thread.currentThread().getName() + ": "
+                        + socket.getLocalPort() + " <  " + socket.getPort();
+                in = new LoggingInputStream(in, name, inLogStream);
+                DEBUG("Created new InterposeInputStream: %s%n", name);
+            }
+            return in;
+        }
+
+        @Override
+        public synchronized OutputStream getOutputStream() throws IOException {
+            if (out == null) {
+                OutputStream o = socket.getOutputStream();
+                String name = Thread.currentThread().getName() + ": "
+                        + socket.getLocalPort() + "  > " + socket.getPort();
+                out = new MatchReplaceOutputStream(o, name, outLogStream, matchBytes, replaceBytes);
+                DEBUG("Created new MatchReplaceOutputStream: %s%n", name);
+            }
+            return out;
+        }
+
+        /**
+         * Return the bytes logged from the input stream.
+         * @return Return the bytes logged from the input stream.
+         */
+        public byte[] getInLogBytes() {
+            return inLogStream.toByteArray();
+        }
+
+        /**
+         * Return the bytes logged from the output stream.
+         * @return Return the bytes logged from the output stream.
+         */
+        public byte[] getOutLogBytes() {
+            return outLogStream.toByteArray();
+        }
+
+    }
+
+    /**
+     * InterposeServerSocket is a ServerSocket that wraps each Socket it accepts
+     * with an InterposeSocket so that its input and output streams can be monitored.
+     */
+    public static class InterposeServerSocket extends ServerSocket {
+        private final ServerSocket socket;
+        private volatile byte[] matchBytes;
+        private volatile byte[] replaceBytes;
+        private final List<InterposeSocket> sockets = new ArrayList<>();
+
+        public InterposeServerSocket(ServerSocket socket, byte[] matchBytes, byte[] replaceBytes) throws IOException {
+            this.socket = socket;
+            this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes");
+            this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes");
+        }
+
+        public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) {
+            this.matchBytes = matchBytes;
+            this.replaceBytes = replaceBytes;
+            sockets.forEach(s -> s.setMatchReplaceBytes(matchBytes, replaceBytes));
+        }
+        /**
+         * Return a snapshot of the current list of sockets created from this server socket.
+         * @return Return a snapshot of the current list of sockets
+         */
+        public List<InterposeSocket> getSockets() {
+            List<InterposeSocket> snap = new ArrayList<>(sockets);
+            return snap;
+        }
+
+        @Override
+        public void bind(SocketAddress endpoint) throws IOException {
+            socket.bind(endpoint);
+        }
+
+        @Override
+        public void bind(SocketAddress endpoint, int backlog) throws IOException {
+            socket.bind(endpoint, backlog);
+        }
+
+        @Override
+        public InetAddress getInetAddress() {
+            return socket.getInetAddress();
+        }
+
+        @Override
+        public int getLocalPort() {
+            return socket.getLocalPort();
+        }
+
+        @Override
+        public SocketAddress getLocalSocketAddress() {
+            return socket.getLocalSocketAddress();
+        }
+
+        @Override
+        public Socket accept() throws IOException {
+            Socket s = socket.accept();
+            InterposeSocket socket = new InterposeSocket(s, matchBytes, replaceBytes);
+            sockets.add(socket);
+            return socket;
+        }
+
+        @Override
+        public void close() throws IOException {
+            socket.close();
+        }
+
+        @Override
+        public ServerSocketChannel getChannel() {
+            return socket.getChannel();
+        }
+
+        @Override
+        public boolean isClosed() {
+            return socket.isClosed();
+        }
+
+        @Override
+        public String toString() {
+            return socket.toString();
+        }
+
+        @Override
+        public <T> ServerSocket setOption(SocketOption<T> name, T value) throws IOException {
+            return socket.setOption(name, value);
+        }
+
+        @Override
+        public <T> T getOption(SocketOption<T> name) throws IOException {
+            return socket.getOption(name);
+        }
+
+        @Override
+        public Set<SocketOption<?>> supportedOptions() {
+            return socket.supportedOptions();
+        }
+
+        @Override
+        public synchronized void setSoTimeout(int timeout) throws SocketException {
+            socket.setSoTimeout(timeout);
+        }
+
+        @Override
+        public synchronized int getSoTimeout() throws IOException {
+            return socket.getSoTimeout();
+        }
+    }
+
+    /**
+     * LoggingInputStream is a stream and logs all bytes read to it.
+     * For identification it is given a name.
+     */
+    public static class LoggingInputStream extends FilterInputStream {
+        private int bytesIn = 0;
+        private final String name;
+        private final OutputStream log;
+
+        public LoggingInputStream(InputStream in, String name, OutputStream log) {
+            super(in);
+            this.name = name;
+            this.log = log;
+        }
+
+        @Override
+        public int read() throws IOException {
+            int b = super.read();
+            if (b >= 0) {
+                log.write(b);
+                bytesIn++;
+            }
+            return b;
+        }
+
+        @Override
+        public int read(byte[] b, int off, int len) throws IOException {
+            int bytes = super.read(b, off, len);
+            if (bytes > 0) {
+                log.write(b, off, bytes);
+                bytesIn += bytes;
+            }
+            return bytes;
+        }
+
+        @Override
+        public int read(byte[] b) throws IOException {
+            return read(b, 0, b.length);
+        }
+
+        @Override
+        public void close() throws IOException {
+            super.close();
+        }
+
+        @Override
+        public String toString() {
+            return String.format("%s: In: (%d)", name, bytesIn);
+        }
+    }
+
+    /**
+     * An OutputStream that replaces one string of bytes with another.
+     * If any range matches, the match starts after the partial match.
+     */
+    static class MatchReplaceOutputStream extends OutputStream {
+        private final OutputStream out;
+        private final String name;
+        private volatile byte[] matchBytes;
+        private volatile byte[] replaceBytes;
+        int matchIndex;
+        private int bytesOut = 0;
+        private final OutputStream log;
+
+        MatchReplaceOutputStream(OutputStream out, String name, OutputStream log,
+                                 byte[] matchBytes, byte[] replaceBytes) {
+            this.out = out;
+            this.name = name;
+            this.matchBytes = Objects.requireNonNull(matchBytes, "matchBytes");
+            this.replaceBytes = Objects.requireNonNull(replaceBytes, "replaceBytes");
+            matchIndex = 0;
+            this.log = log;
+        }
+
+        public void setMatchReplaceBytes(byte[] matchBytes, byte[] replaceBytes) {
+            this.matchBytes = matchBytes;
+            this.replaceBytes = replaceBytes;
+            matchIndex = 0;
+        }
+
+
+        public void write(int b) throws IOException {
+            b = b & 0xff;
+            if (matchBytes.length == 0) {
+                out.write(b);
+                log.write(b);
+                bytesOut++;
+                return;
+            }
+            if (b == (matchBytes[matchIndex] & 0xff)) {
+                if (++matchIndex >= matchBytes.length) {
+                    matchIndex = 0;
+                    DEBUG( "TestSocketFactory MatchReplace %s replaced %d bytes at offset: %d (x%04x)%n",
+                            name, replaceBytes.length, bytesOut, bytesOut);
+                    out.write(replaceBytes);
+                    log.write(replaceBytes);
+                    bytesOut += replaceBytes.length;
+                }
+            } else {
+                if (matchIndex > 0) {
+                    // mismatch, write out any that matched already
+                    if (matchIndex > 0) // Only non-trivial matches
+                        DEBUG( "Partial match %s matched %d bytes at offset: %d (0x%04x), expected: x%02x, actual: x%02x%n",
+                                name, matchIndex, bytesOut, bytesOut,  matchBytes[matchIndex], b);
+                    out.write(matchBytes, 0, matchIndex);
+                    log.write(matchBytes, 0, matchIndex);
+                    bytesOut += matchIndex;
+                    matchIndex = 0;
+                }
+                if (b == (matchBytes[matchIndex] & 0xff)) {
+                    matchIndex++;
+                } else {
+                    out.write(b);
+                    log.write(b);
+                    bytesOut++;
+                }
+            }
+        }
+
+        @Override
+        public String toString() {
+            return String.format("%s: Out: (%d)", name, bytesOut);
+        }
+    }
+
+    private static byte[] orig = new byte[]{
+            (byte) 0x80, 0x05,
+            0x73, 0x72, 0x00, 0x12, // TC_OBJECT, TC_CLASSDESC, length = 18
+            0x6A, 0x61, 0x76, 0x61, 0x2E, 0x72, 0x6D, 0x69, 0x2E, // "java.rmi."
+            0x64, 0x67, 0x63, 0x2E, 0x4C, 0x65, 0x61, 0x73, 0x65  // "dgc.Lease"
+    };
+    private static byte[] repl = new byte[]{
+            (byte) 0x80, 0x05,
+            0x73, 0x72, 0x00, 0x12, // TC_OBJECT, TC_CLASSDESC, length = 18
+            0x6A, 0x61, 0x76, 0x61, 0x2E, (byte) 'l', (byte) 'a', (byte) 'n', (byte) 'g',
+            0x2E, (byte) 'R', (byte) 'u', (byte) 'n', (byte) 'n', (byte) 'a', (byte) 'b', (byte) 'l',
+            (byte) 'e'
+    };
+
+    @DataProvider(name = "MatchReplaceData")
+    static Object[][] matchReplaceData() {
+        byte[] empty = new byte[0];
+        byte[] byte1 = new byte[]{1, 2, 3, 4, 5, 6};
+        byte[] bytes2 = new byte[]{1, 2, 4, 3, 5, 6};
+        byte[] bytes3 = new byte[]{6, 5, 4, 3, 2, 1};
+        byte[] bytes4 = new byte[]{1, 2, 0x10, 0x20, 0x30, 0x40, 5, 6};
+        byte[] bytes4a = new byte[]{1, 2, 0x10, 0x20, 0x30, 0x40, 5, 7};  // mostly matches bytes4
+        byte[] bytes5 = new byte[]{0x30, 0x40, 5, 6};
+        byte[] bytes6 = new byte[]{1, 2, 0x10, 0x20, 0x30};
+
+        return new Object[][]{
+                {new byte[]{}, new byte[]{}, empty, empty},
+                {new byte[]{}, new byte[]{}, byte1, byte1},
+                {new byte[]{3, 4}, new byte[]{4, 3}, byte1, bytes2}, //swap bytes
+                {new byte[]{3, 4}, new byte[]{0x10, 0x20, 0x30, 0x40}, byte1, bytes4}, // insert
+                {new byte[]{1, 2, 0x10, 0x20}, new byte[]{}, bytes4, bytes5}, // delete head
+                {new byte[]{0x40, 5, 6}, new byte[]{}, bytes4, bytes6},   // delete tail
+                {new byte[]{0x40, 0x50}, new byte[]{0x60, 0x50}, bytes4, bytes4}, // partial match, replace nothing
+                {bytes4a, bytes3, bytes4, bytes4}, // long partial match, not replaced
+                {orig, repl, orig, repl},
+        };
+    }
+
+    @Test(enabled = true, dataProvider = "MatchReplaceData")
+    static void test3(byte[] match, byte[] replace,
+                      byte[] input, byte[] expected) {
+        System.out.printf("match: %s, replace: %s%n", Arrays.toString(match), Arrays.toString(replace));
+        try (ByteArrayOutputStream output = new ByteArrayOutputStream();
+        ByteArrayOutputStream log = new ByteArrayOutputStream();
+             OutputStream out = new MatchReplaceOutputStream(output, "test3",
+                     log, match, replace)) {
+            out.write(input);
+            byte[] actual = output.toByteArray();
+            long index = Arrays.mismatch(actual, expected);
+
+            if (index >= 0) {
+                System.out.printf("array mismatch, offset: %d%n", index);
+                System.out.printf("actual: %s%n", Arrays.toString(actual));
+                System.out.printf("expected: %s%n", Arrays.toString(expected));
+            }
+            Assert.assertEquals(actual, expected, "match/replace fail");
+        } catch (IOException ioe) {
+            Assert.fail("unexpected exception", ioe);
+        }
+    }
+
+
+
+}