changeset 1532:8dd9bcaee77e

RequestQueue should refuse invalid Requests Reviewed-by: jerboaa, vanaltj Review-thread: http://icedtea.classpath.org/pipermail/thermostat/2014-October/011334.html
author Omair Majid <omajid@redhat.com>
date Tue, 28 Oct 2014 17:00:23 -0400
parents cbbf264033b4
children d5a7cb4fab04
files agent/command/src/main/java/com/redhat/thermostat/agent/command/internal/ServerHandler.java agent/command/src/test/java/com/redhat/thermostat/agent/command/internal/ServerHandlerTest.java client/command/src/main/java/com/redhat/thermostat/client/command/internal/RequestQueueImpl.java client/command/src/test/java/com/redhat/thermostat/client/command/internal/RequestQueueImplTest.java
diffstat 4 files changed, 129 insertions(+), 18 deletions(-) [+]
line wrap: on
line diff
--- a/agent/command/src/main/java/com/redhat/thermostat/agent/command/internal/ServerHandler.java	Tue Oct 28 14:12:32 2014 -0400
+++ b/agent/command/src/main/java/com/redhat/thermostat/agent/command/internal/ServerHandler.java	Tue Oct 28 17:00:23 2014 -0400
@@ -56,6 +56,7 @@
 
 import com.redhat.thermostat.agent.command.ReceiverRegistry;
 import com.redhat.thermostat.agent.command.RequestReceiver;
+import com.redhat.thermostat.common.command.Message.MessageType;
 import com.redhat.thermostat.common.command.Request;
 import com.redhat.thermostat.common.command.Response;
 import com.redhat.thermostat.common.command.Response.ResponseType;
@@ -70,8 +71,15 @@
     private static final Logger logger = LoggingUtils.getLogger(ServerHandler.class);
     private ReceiverRegistry receivers;
     private SSLConfiguration sslConf;
+    private StorageGetter storageGetter;
 
     public ServerHandler(ReceiverRegistry receivers, SSLConfiguration sslConf) {
+        this(receivers, sslConf, new StorageGetter());
+    }
+
+    /** For testing only */
+    ServerHandler(ReceiverRegistry receivers, SSLConfiguration sslConf, StorageGetter getter) {
+        this.storageGetter = getter;
         this.receivers = receivers;
         this.sslConf = sslConf;
     }
@@ -103,17 +111,24 @@
     @Override
     public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) {
         Request request = (Request) e.getMessage();
+        String receiverName = request.getReceiver();
+        MessageType requestType = request.getType();
+        logger.info("Request received: '" + requestType + "' for '" + receiverName + "'");
         boolean authSucceeded = authenticateRequestIfNecessary(request);
         Response response = null;
         if (! authSucceeded) {
+            logger.info("Authentication for request failed");
             response = new Response(ResponseType.AUTH_FAILED);
         } else {
-            String receiverName = request.getReceiver();
-            logger.info("Request received: " + request.getType().toString() + " for " + receiverName);
-            RequestReceiver receiver = receivers.getReceiver(receiverName);
-            if (receiver != null) {
-                response = receiver.receive(request);
-            } else {
+            if (receiverName != null && requestType != null) {
+                RequestReceiver receiver = receivers.getReceiver(receiverName);
+                if (receiver != null) {
+                    response = receiver.receive(request);
+                }
+            }
+
+            if (response == null) {
+                logger.info("Receiver with name '" + receiverName + "' not found ");
                 response = new Response(ResponseType.ERROR);
             }
         }
@@ -128,9 +143,7 @@
     }
 
     private boolean authenticateRequestIfNecessary(Request request) {
-        BundleContext bCtx = FrameworkUtil.getBundle(getClass()).getBundleContext();
-        ServiceReference storageRef = bCtx.getServiceReference(Storage.class.getName());
-        Storage storage = (Storage) bCtx.getService(storageRef);
+        Storage storage = storageGetter.get();
         if (storage instanceof SecureStorage) {
             boolean authenticatedRequest = authenticateRequest(request, (SecureStorage) storage);
             if (authenticatedRequest) {
@@ -182,5 +195,16 @@
             }
         }
     }
+
+    /** for testing only */
+    static class StorageGetter {
+        public Storage get() {
+            BundleContext bCtx = FrameworkUtil.getBundle(getClass()).getBundleContext();
+            ServiceReference<Storage> storageRef = bCtx.getServiceReference(Storage.class);
+            // FIXME there should be a matching unget() somewhere to release the reference
+            Storage storage = (Storage) bCtx.getService(storageRef);
+            return storage;
+        }
+    }
 }
 
--- a/agent/command/src/test/java/com/redhat/thermostat/agent/command/internal/ServerHandlerTest.java	Tue Oct 28 14:12:32 2014 -0400
+++ b/agent/command/src/test/java/com/redhat/thermostat/agent/command/internal/ServerHandlerTest.java	Tue Oct 28 17:00:23 2014 -0400
@@ -36,38 +36,80 @@
 
 package com.redhat.thermostat.agent.command.internal;
 
+import static org.junit.Assert.assertEquals;
 import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.isA;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import org.jboss.netty.channel.Channel;
 import org.jboss.netty.channel.ChannelFuture;
 import org.jboss.netty.channel.ChannelHandlerContext;
 import org.jboss.netty.channel.ChannelPipeline;
+import org.jboss.netty.channel.MessageEvent;
 import org.jboss.netty.handler.ssl.SslHandler;
+import org.junit.Before;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 
 import com.redhat.thermostat.agent.command.internal.ServerHandler.SSLHandshakeDoneListener;
+import com.redhat.thermostat.agent.command.internal.ServerHandler.StorageGetter;
+import com.redhat.thermostat.common.command.Request;
+import com.redhat.thermostat.common.command.Response;
+import com.redhat.thermostat.common.command.Response.ResponseType;
 import com.redhat.thermostat.shared.config.SSLConfiguration;
 
 public class ServerHandlerTest {
 
+    private StorageGetter storageGetter;
+
+    private Channel channel;
+    private ChannelHandlerContext ctx;
+
+    @Before
+    public void setup() {
+        channel = mock(Channel.class);
+        when(channel.isConnected()).thenReturn(true);
+        ChannelFuture channelFuture = mock(ChannelFuture.class);
+        when(channel.write(isA(Response.class))).thenReturn(channelFuture);
+
+        ctx = mock(ChannelHandlerContext.class);
+        when(ctx.getChannel()).thenReturn(channel);
+
+        storageGetter = mock(StorageGetter.class);
+    }
+
     @Test
     public void channelConnectedAddsSSLListener() throws Exception {
         SSLConfiguration mockSSLConf = mock(SSLConfiguration.class);
         when(mockSSLConf.enableForCmdChannel()).thenReturn(true);
-        ServerHandler handler = new ServerHandler(null, mockSSLConf);
+        ServerHandler handler = new ServerHandler(null, mockSSLConf, storageGetter);
         
-        ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
         ChannelPipeline pipeline = mock(ChannelPipeline.class);
         when(ctx.getPipeline()).thenReturn(pipeline);
         SslHandler sslHandler = mock(SslHandler.class);
         when(pipeline.get(SslHandler.class)).thenReturn(sslHandler);
-        ChannelFuture future = mock(ChannelFuture.class);
-        when(sslHandler.handshake()).thenReturn(future);
+        ChannelFuture handshakeFuture = mock(ChannelFuture.class);
+        when(sslHandler.handshake()).thenReturn(handshakeFuture);
         
         handler.channelConnected(ctx, null);
-        verify(future).addListener(any(SSLHandshakeDoneListener.class));
+        verify(handshakeFuture).addListener(any(SSLHandshakeDoneListener.class));
+    }
+
+    @Test
+    public void invalidRequestReturnsAnErrorResponse() {
+        // target and receiver are null
+        Request request = mock(Request.class);
+        MessageEvent event = mock(MessageEvent.class);
+        when(event.getMessage()).thenReturn(request);
+
+        ServerHandler handler = new ServerHandler(null, null, storageGetter);
+        handler.messageReceived(ctx, event);
+
+        ArgumentCaptor<Response> responseCaptor = ArgumentCaptor.forClass(Response.class);
+        verify(channel).write(responseCaptor.capture());
+        assertEquals(ResponseType.ERROR, responseCaptor.getValue().getType());
     }
 }
 
--- a/client/command/src/main/java/com/redhat/thermostat/client/command/internal/RequestQueueImpl.java	Tue Oct 28 14:12:32 2014 -0400
+++ b/client/command/src/main/java/com/redhat/thermostat/client/command/internal/RequestQueueImpl.java	Tue Oct 28 17:00:23 2014 -0400
@@ -41,7 +41,6 @@
 import java.util.logging.Level;
 import java.util.logging.Logger;
 
-
 import org.apache.commons.codec.binary.Base64;
 import org.jboss.netty.bootstrap.ClientBootstrap;
 import org.jboss.netty.channel.Channel;
@@ -79,12 +78,23 @@
 
     @Override
     public void putRequest(Request request) {
+        assertValidRequest(request);
+
         // Only enqueue request if we've successfully authenticated
         if (authenticateRequest(request)) {
             queue.add(request);
         }
     }
 
+    private void assertValidRequest(Request request) {
+        if (request.getReceiver() == null) {
+            throw new AssertionError("The receiver for a Request must not be null");
+        }
+        if (request.getTarget() == null) {
+            throw new AssertionError("The target for a Request must not be null");
+        }
+    }
+
     private boolean authenticateRequest(Request request) {
         boolean result = true; // Successful by default, unless storage is secure
         BundleContext bCtx = FrameworkUtil.getBundle(getClass()).getBundleContext();
--- a/client/command/src/test/java/com/redhat/thermostat/client/command/internal/RequestQueueImplTest.java	Tue Oct 28 14:12:32 2014 -0400
+++ b/client/command/src/test/java/com/redhat/thermostat/client/command/internal/RequestQueueImplTest.java	Tue Oct 28 17:00:23 2014 -0400
@@ -39,12 +39,14 @@
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
+import java.net.InetSocketAddress;
 import java.util.Arrays;
 
 import org.junit.Before;
@@ -86,6 +88,32 @@
         when(FrameworkUtil.getBundle(RequestQueueImpl.class)).thenReturn(mockBundle);
     }
 
+    @Test
+    public void putRequestRejectsRequestWithNullTarget() {
+        Request request = createRequest(null, "foobar");
+        ConfigurationRequestContext ctx = mock(ConfigurationRequestContext.class);
+        RequestQueueImpl queue = new RequestQueueImpl(ctx);
+        try {
+            queue.putRequest(request);
+            fail("expected assertion not thrown");
+        } catch (AssertionError e) {
+            // okay
+        }
+    }
+
+    @Test
+    public void putRequestRejectsRequestWithNullReceiver() {
+        Request request = createRequest(mock(InetSocketAddress.class), null);
+        ConfigurationRequestContext ctx = mock(ConfigurationRequestContext.class);
+        RequestQueueImpl queue = new RequestQueueImpl(ctx);
+        try {
+            queue.putRequest(request);
+            fail("expected assertion not thrown");
+        } catch (AssertionError e) {
+            // okay
+        }
+    }
+
     /*
      * Other tests ensure that secure storage is returned from storage providers.
      * This is an attempt to make sure that authentication hooks are actually
@@ -102,7 +130,7 @@
         
         ConfigurationRequestContext ctx = mock(ConfigurationRequestContext.class);
         RequestQueueImpl queue = new RequestQueueImpl(ctx);
-        Request request = mock(Request.class);
+        Request request = createRequest(mock(InetSocketAddress.class), "");
         queue.putRequest(request);
         verify(request).setParameter(eq(Request.CLIENT_TOKEN), any(String.class));
         verify(request).setParameter(eq(Request.AUTH_TOKEN), any(String.class));
@@ -118,7 +146,7 @@
         
         ConfigurationRequestContext ctx = mock(ConfigurationRequestContext.class);
         RequestQueueImpl queue = new RequestQueueImpl(ctx);
-        Request request = mock(Request.class);
+        Request request = createRequest(mock(InetSocketAddress.class), "");
         
         RequestResponseListener listener = mock(RequestResponseListener.class);
         when(request.getListeners()).thenReturn(Arrays.asList(listener));
@@ -137,9 +165,16 @@
         
         ConfigurationRequestContext ctx = mock(ConfigurationRequestContext.class);
         RequestQueueImpl queue = new RequestQueueImpl(ctx);
-        Request request = mock(Request.class);
+        Request request = createRequest(mock(InetSocketAddress.class), "");
         queue.putRequest(request);
         assertTrue(queue.getQueue().contains(request));
     }
+
+    private static Request createRequest(InetSocketAddress agentAddress, String receiver) {
+        Request request = mock(Request.class);
+        when(request.getReceiver()).thenReturn(receiver);
+        when(request.getTarget()).thenReturn(agentAddress);
+        return request;
+    }
 }