Mercurial > hg > release > thermostat-1.2
changeset 1532:8dd9bcaee77e
RequestQueue should refuse invalid Requests
Reviewed-by: jerboaa, vanaltj
Review-thread: http://icedtea.classpath.org/pipermail/thermostat/2014-October/011334.html
author | Omair Majid <omajid@redhat.com> |
---|---|
date | Tue, 28 Oct 2014 17:00:23 -0400 |
parents | cbbf264033b4 |
children | d5a7cb4fab04 |
files | agent/command/src/main/java/com/redhat/thermostat/agent/command/internal/ServerHandler.java agent/command/src/test/java/com/redhat/thermostat/agent/command/internal/ServerHandlerTest.java client/command/src/main/java/com/redhat/thermostat/client/command/internal/RequestQueueImpl.java client/command/src/test/java/com/redhat/thermostat/client/command/internal/RequestQueueImplTest.java |
diffstat | 4 files changed, 129 insertions(+), 18 deletions(-) [+] |
line wrap: on
line diff
--- a/agent/command/src/main/java/com/redhat/thermostat/agent/command/internal/ServerHandler.java Tue Oct 28 14:12:32 2014 -0400 +++ b/agent/command/src/main/java/com/redhat/thermostat/agent/command/internal/ServerHandler.java Tue Oct 28 17:00:23 2014 -0400 @@ -56,6 +56,7 @@ import com.redhat.thermostat.agent.command.ReceiverRegistry; import com.redhat.thermostat.agent.command.RequestReceiver; +import com.redhat.thermostat.common.command.Message.MessageType; import com.redhat.thermostat.common.command.Request; import com.redhat.thermostat.common.command.Response; import com.redhat.thermostat.common.command.Response.ResponseType; @@ -70,8 +71,15 @@ private static final Logger logger = LoggingUtils.getLogger(ServerHandler.class); private ReceiverRegistry receivers; private SSLConfiguration sslConf; + private StorageGetter storageGetter; public ServerHandler(ReceiverRegistry receivers, SSLConfiguration sslConf) { + this(receivers, sslConf, new StorageGetter()); + } + + /** For testing only */ + ServerHandler(ReceiverRegistry receivers, SSLConfiguration sslConf, StorageGetter getter) { + this.storageGetter = getter; this.receivers = receivers; this.sslConf = sslConf; } @@ -103,17 +111,24 @@ @Override public void messageReceived(ChannelHandlerContext ctx, MessageEvent e) { Request request = (Request) e.getMessage(); + String receiverName = request.getReceiver(); + MessageType requestType = request.getType(); + logger.info("Request received: '" + requestType + "' for '" + receiverName + "'"); boolean authSucceeded = authenticateRequestIfNecessary(request); Response response = null; if (! authSucceeded) { + logger.info("Authentication for request failed"); response = new Response(ResponseType.AUTH_FAILED); } else { - String receiverName = request.getReceiver(); - logger.info("Request received: " + request.getType().toString() + " for " + receiverName); - RequestReceiver receiver = receivers.getReceiver(receiverName); - if (receiver != null) { - response = receiver.receive(request); - } else { + if (receiverName != null && requestType != null) { + RequestReceiver receiver = receivers.getReceiver(receiverName); + if (receiver != null) { + response = receiver.receive(request); + } + } + + if (response == null) { + logger.info("Receiver with name '" + receiverName + "' not found "); response = new Response(ResponseType.ERROR); } } @@ -128,9 +143,7 @@ } private boolean authenticateRequestIfNecessary(Request request) { - BundleContext bCtx = FrameworkUtil.getBundle(getClass()).getBundleContext(); - ServiceReference storageRef = bCtx.getServiceReference(Storage.class.getName()); - Storage storage = (Storage) bCtx.getService(storageRef); + Storage storage = storageGetter.get(); if (storage instanceof SecureStorage) { boolean authenticatedRequest = authenticateRequest(request, (SecureStorage) storage); if (authenticatedRequest) { @@ -182,5 +195,16 @@ } } } + + /** for testing only */ + static class StorageGetter { + public Storage get() { + BundleContext bCtx = FrameworkUtil.getBundle(getClass()).getBundleContext(); + ServiceReference<Storage> storageRef = bCtx.getServiceReference(Storage.class); + // FIXME there should be a matching unget() somewhere to release the reference + Storage storage = (Storage) bCtx.getService(storageRef); + return storage; + } + } }
--- a/agent/command/src/test/java/com/redhat/thermostat/agent/command/internal/ServerHandlerTest.java Tue Oct 28 14:12:32 2014 -0400 +++ b/agent/command/src/test/java/com/redhat/thermostat/agent/command/internal/ServerHandlerTest.java Tue Oct 28 17:00:23 2014 -0400 @@ -36,38 +36,80 @@ package com.redhat.thermostat.agent.command.internal; +import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.isA; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import org.jboss.netty.channel.Channel; import org.jboss.netty.channel.ChannelFuture; import org.jboss.netty.channel.ChannelHandlerContext; import org.jboss.netty.channel.ChannelPipeline; +import org.jboss.netty.channel.MessageEvent; import org.jboss.netty.handler.ssl.SslHandler; +import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; import com.redhat.thermostat.agent.command.internal.ServerHandler.SSLHandshakeDoneListener; +import com.redhat.thermostat.agent.command.internal.ServerHandler.StorageGetter; +import com.redhat.thermostat.common.command.Request; +import com.redhat.thermostat.common.command.Response; +import com.redhat.thermostat.common.command.Response.ResponseType; import com.redhat.thermostat.shared.config.SSLConfiguration; public class ServerHandlerTest { + private StorageGetter storageGetter; + + private Channel channel; + private ChannelHandlerContext ctx; + + @Before + public void setup() { + channel = mock(Channel.class); + when(channel.isConnected()).thenReturn(true); + ChannelFuture channelFuture = mock(ChannelFuture.class); + when(channel.write(isA(Response.class))).thenReturn(channelFuture); + + ctx = mock(ChannelHandlerContext.class); + when(ctx.getChannel()).thenReturn(channel); + + storageGetter = mock(StorageGetter.class); + } + @Test public void channelConnectedAddsSSLListener() throws Exception { SSLConfiguration mockSSLConf = mock(SSLConfiguration.class); when(mockSSLConf.enableForCmdChannel()).thenReturn(true); - ServerHandler handler = new ServerHandler(null, mockSSLConf); + ServerHandler handler = new ServerHandler(null, mockSSLConf, storageGetter); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); ChannelPipeline pipeline = mock(ChannelPipeline.class); when(ctx.getPipeline()).thenReturn(pipeline); SslHandler sslHandler = mock(SslHandler.class); when(pipeline.get(SslHandler.class)).thenReturn(sslHandler); - ChannelFuture future = mock(ChannelFuture.class); - when(sslHandler.handshake()).thenReturn(future); + ChannelFuture handshakeFuture = mock(ChannelFuture.class); + when(sslHandler.handshake()).thenReturn(handshakeFuture); handler.channelConnected(ctx, null); - verify(future).addListener(any(SSLHandshakeDoneListener.class)); + verify(handshakeFuture).addListener(any(SSLHandshakeDoneListener.class)); + } + + @Test + public void invalidRequestReturnsAnErrorResponse() { + // target and receiver are null + Request request = mock(Request.class); + MessageEvent event = mock(MessageEvent.class); + when(event.getMessage()).thenReturn(request); + + ServerHandler handler = new ServerHandler(null, null, storageGetter); + handler.messageReceived(ctx, event); + + ArgumentCaptor<Response> responseCaptor = ArgumentCaptor.forClass(Response.class); + verify(channel).write(responseCaptor.capture()); + assertEquals(ResponseType.ERROR, responseCaptor.getValue().getType()); } }
--- a/client/command/src/main/java/com/redhat/thermostat/client/command/internal/RequestQueueImpl.java Tue Oct 28 14:12:32 2014 -0400 +++ b/client/command/src/main/java/com/redhat/thermostat/client/command/internal/RequestQueueImpl.java Tue Oct 28 17:00:23 2014 -0400 @@ -41,7 +41,6 @@ import java.util.logging.Level; import java.util.logging.Logger; - import org.apache.commons.codec.binary.Base64; import org.jboss.netty.bootstrap.ClientBootstrap; import org.jboss.netty.channel.Channel; @@ -79,12 +78,23 @@ @Override public void putRequest(Request request) { + assertValidRequest(request); + // Only enqueue request if we've successfully authenticated if (authenticateRequest(request)) { queue.add(request); } } + private void assertValidRequest(Request request) { + if (request.getReceiver() == null) { + throw new AssertionError("The receiver for a Request must not be null"); + } + if (request.getTarget() == null) { + throw new AssertionError("The target for a Request must not be null"); + } + } + private boolean authenticateRequest(Request request) { boolean result = true; // Successful by default, unless storage is secure BundleContext bCtx = FrameworkUtil.getBundle(getClass()).getBundleContext();
--- a/client/command/src/test/java/com/redhat/thermostat/client/command/internal/RequestQueueImplTest.java Tue Oct 28 14:12:32 2014 -0400 +++ b/client/command/src/test/java/com/redhat/thermostat/client/command/internal/RequestQueueImplTest.java Tue Oct 28 17:00:23 2014 -0400 @@ -39,12 +39,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import java.net.InetSocketAddress; import java.util.Arrays; import org.junit.Before; @@ -86,6 +88,32 @@ when(FrameworkUtil.getBundle(RequestQueueImpl.class)).thenReturn(mockBundle); } + @Test + public void putRequestRejectsRequestWithNullTarget() { + Request request = createRequest(null, "foobar"); + ConfigurationRequestContext ctx = mock(ConfigurationRequestContext.class); + RequestQueueImpl queue = new RequestQueueImpl(ctx); + try { + queue.putRequest(request); + fail("expected assertion not thrown"); + } catch (AssertionError e) { + // okay + } + } + + @Test + public void putRequestRejectsRequestWithNullReceiver() { + Request request = createRequest(mock(InetSocketAddress.class), null); + ConfigurationRequestContext ctx = mock(ConfigurationRequestContext.class); + RequestQueueImpl queue = new RequestQueueImpl(ctx); + try { + queue.putRequest(request); + fail("expected assertion not thrown"); + } catch (AssertionError e) { + // okay + } + } + /* * Other tests ensure that secure storage is returned from storage providers. * This is an attempt to make sure that authentication hooks are actually @@ -102,7 +130,7 @@ ConfigurationRequestContext ctx = mock(ConfigurationRequestContext.class); RequestQueueImpl queue = new RequestQueueImpl(ctx); - Request request = mock(Request.class); + Request request = createRequest(mock(InetSocketAddress.class), ""); queue.putRequest(request); verify(request).setParameter(eq(Request.CLIENT_TOKEN), any(String.class)); verify(request).setParameter(eq(Request.AUTH_TOKEN), any(String.class)); @@ -118,7 +146,7 @@ ConfigurationRequestContext ctx = mock(ConfigurationRequestContext.class); RequestQueueImpl queue = new RequestQueueImpl(ctx); - Request request = mock(Request.class); + Request request = createRequest(mock(InetSocketAddress.class), ""); RequestResponseListener listener = mock(RequestResponseListener.class); when(request.getListeners()).thenReturn(Arrays.asList(listener)); @@ -137,9 +165,16 @@ ConfigurationRequestContext ctx = mock(ConfigurationRequestContext.class); RequestQueueImpl queue = new RequestQueueImpl(ctx); - Request request = mock(Request.class); + Request request = createRequest(mock(InetSocketAddress.class), ""); queue.putRequest(request); assertTrue(queue.getQueue().contains(request)); } + + private static Request createRequest(InetSocketAddress agentAddress, String receiver) { + Request request = mock(Request.class); + when(request.getReceiver()).thenReturn(receiver); + when(request.getTarget()).thenReturn(agentAddress); + return request; + } }