changeset 2730:08bb838d3c6c

[commands] Determine receiver based on action. Reviewed-by: neugens Review-thread: http://icedtea.classpath.org/pipermail/thermostat/2017-July/024339.html
author Severin Gehwolf <sgehwolf@redhat.com>
date Fri, 28 Jul 2017 12:16:49 +0200
parents 9cbc2bcf39c2
children 406fe8f9d8bf
files plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/receiver/PingReceiver.java plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/socket/AgentSocketOnMessageCallback.java plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/typeadapters/AgentRequestTypeAdapter.java plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/typeadapters/BasicMessageTypeAdapter.java plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/receiver/ReceiverRegistry.java plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/model/AgentRequest.java plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/model/Message.java plugins/commands/agent/src/test/java/com/redhat/thermostat/commands/agent/internal/receiver/PingReceiverTest.java plugins/commands/agent/src/test/java/com/redhat/thermostat/commands/agent/internal/socket/AgentSocketOnMessageCallbackTest.java plugins/commands/agent/src/test/java/com/redhat/thermostat/commands/agent/receiver/ReceiverRegistryTest.java plugins/killvm/agent/src/main/java/com/redhat/thermostat/killvm/agent/internal/KillVmReceiver.java plugins/killvm/agent/src/test/java/com/redhat/thermostat/killvm/agent/internal/KillVmReceiverTest.java
diffstat 12 files changed, 175 insertions(+), 110 deletions(-) [+]
line wrap: on
line diff
--- a/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/receiver/PingReceiver.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/receiver/PingReceiver.java	Fri Jul 28 12:16:49 2017 +0200
@@ -52,9 +52,10 @@
 
 @Component
 @Service(value = RequestReceiver.class)
-@Property(name = "servicename", value = "com.redhat.thermostat.commands.agent.internal.receiver.PingReceiver")
+@Property(name = "servicename", value = PingReceiver.ACTION_NAME)
 public class PingReceiver implements RequestReceiver {
 
+    public static final String ACTION_NAME = "ping";
     private static Logger logger = LoggingUtils.getLogger(PingReceiver.class);
     
     @Activate
@@ -64,6 +65,11 @@
     
     @Override
     public WebSocketResponse receive(AgentRequest request) {
+        // Sanity check. We should never get requests outside our action domain.
+        if (!ACTION_NAME.equals(request.getAction())) {
+            logger.severe("Received action '" + request.getAction() + "' for receiver '" + ACTION_NAME + "'");
+            return new WebSocketResponse(request.getSequenceId(), ResponseType.ERROR);
+        }
         return new WebSocketResponse(request.getSequenceId(), ResponseType.OK);
     }
 
--- a/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/socket/AgentSocketOnMessageCallback.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/socket/AgentSocketOnMessageCallback.java	Fri Jul 28 12:16:49 2017 +0200
@@ -107,14 +107,7 @@
 
         @Override
         public void run() {
-            String receiverName = request
-                    .getParam(AgentRequest.RECEIVER_PARAM_NAME);
-            if (receiverName == null) {
-                String msg = "No receiver specified in cmd-channel request with sequence "
-                        + request.getSequenceId();
-                handleError(msg);
-                return;
-            }
+            String receiverName = request.getAction();
             RequestReceiver receiver = receivers.getReceiver(receiverName);
             if (receiver == null) {
                 String msg = "Got cmd-channel request for receiver '"
--- a/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/typeadapters/AgentRequestTypeAdapter.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/typeadapters/AgentRequestTypeAdapter.java	Fri Jul 28 12:16:49 2017 +0200
@@ -55,6 +55,12 @@
     @Override
     public void write(JsonWriter out, AgentRequest request) throws IOException {
         JsonObject object = getEnvelopeWithTypeAndSequence(request, request.getSequenceId());
+        JsonElement actionElem = gson.toJsonTree(request.getAction());
+        object.add(Message.ACTION_KEY, actionElem);
+        JsonElement jvmIdElem = gson.toJsonTree(request.getJvmId());
+        object.add(Message.JVM_ID_KEY, jvmIdElem);
+        JsonElement systemIdElem = gson.toJsonTree(request.getSystemId());
+        object.add(Message.SYSTEM_ID_KEY, systemIdElem);
         JsonElement parms = gson.toJsonTree(request.getParams());
         object.add(Message.PAYLOAD_KEY, parms);
         gson.toJson(object, out);
--- a/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/typeadapters/BasicMessageTypeAdapter.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/internal/typeadapters/BasicMessageTypeAdapter.java	Fri Jul 28 12:16:49 2017 +0200
@@ -51,8 +51,8 @@
 import com.redhat.thermostat.commands.model.AgentRequest;
 import com.redhat.thermostat.commands.model.ClientRequest;
 import com.redhat.thermostat.commands.model.Message;
+import com.redhat.thermostat.commands.model.Message.MessageType;
 import com.redhat.thermostat.commands.model.WebSocketResponse;
-import com.redhat.thermostat.commands.model.Message.MessageType;
 import com.redhat.thermostat.commands.model.WebSocketResponse.ResponseType;
 
 abstract class BasicMessageTypeAdapter<T extends Message> extends TypeAdapter<T> {
@@ -86,18 +86,14 @@
         RawMessage raw = getRawMessageFromReader(in);
         switch (raw.getMessageType()) {
         case AGENT_REQUEST: {
-            requireSequenceNonNull(raw.getSequenceElement(),
-                                   "Agent request without a sequence!");
-            long sequence = raw.getSequenceElement().getAsLong();
-            SortedMap<String, String> params = decodePayloadAsParamMap(raw.getPayloadElement());
-            return new AgentRequest(sequence, params);
+            return decodeAgentRequest(raw);
         }
         case CLIENT_REQUEST: {
             SortedMap<String, String> params = decodePayloadAsParamMap(raw.getPayloadElement());
             return new ClientRequest(params);
         }
         case RESPONSE: {
-            requireSequenceNonNull(raw.getSequenceElement(),
+            requireElementNonNull(raw.getSequenceElement(),
                                    "Response message without a sequence!");
             long sequence = raw.getSequenceElement().getAsLong();
             ResponseType respType = decodePayloadAsResponseType(raw.getPayloadElement());
@@ -108,6 +104,23 @@
         }
     }
 
+    private Message decodeAgentRequest(RawMessage raw) {
+        requireElementNonNull(raw.getSequenceElement(),
+                               "Agent request without a sequence!");
+        long sequence = raw.getSequenceElement().getAsLong();
+        requireElementNonNull(raw.getActionElement(),
+                              "Agent request without action!");
+        String action = raw.getActionElement().getAsString();
+        requireElementNonNull(raw.getJvmIdElement(),
+                              "Agent request without jvmId!");
+        String jvmId = raw.getJvmIdElement().getAsString();
+        requireElementNonNull(raw.getSystemIdElement(),
+                "Agent request without systemId!");
+        String systemId = raw.getSystemIdElement().getAsString();
+        SortedMap<String, String> params = decodePayloadAsParamMap(raw.getPayloadElement());
+        return new AgentRequest(sequence, action, systemId, jvmId, params);
+    }
+
     private SortedMap<String, String> decodePayloadAsParamMap(JsonObject payloadElem) {
         SortedMap<String, String> paramMap = new TreeMap<>();
         if (payloadElem == null) {
@@ -133,8 +146,8 @@
         return ResponseType.valueOf(typeStr);
     }
 
-    private void requireSequenceNonNull(JsonElement sequenceElem, String msg) throws IllegalStateException {
-        if (sequenceElem == null) {
+    private void requireElementNonNull(JsonElement elem, String msg) throws IllegalStateException {
+        if (elem == null) {
             throw new IllegalStateException(msg);
         }
     }
@@ -151,6 +164,9 @@
         private final JsonElement typeElem;
         private final JsonElement sequenceElem;
         private final JsonObject payloadElem;
+        private final JsonElement actionElem;
+        private final JsonElement jvmIdElem;
+        private final JsonElement systemIdElem;
 
         private RawMessage(JsonObject object) {
             Objects.requireNonNull(object);
@@ -158,6 +174,9 @@
             typeElem = object.get(Message.TYPE_KEY);
             sequenceElem = object.get(Message.SEQUENCE_KEY);
             payloadElem = (JsonObject)object.get(Message.PAYLOAD_KEY);
+            actionElem = object.get(Message.ACTION_KEY);
+            jvmIdElem = object.get(Message.JVM_ID_KEY);
+            systemIdElem = object.get(Message.SYSTEM_ID_KEY);
         }
 
         MessageType getMessageType() {
@@ -170,6 +189,18 @@
             return sequenceElem;
         }
 
+        JsonElement getActionElement() {
+            return actionElem;
+        }
+
+        JsonElement getJvmIdElement() {
+            return jvmIdElem;
+        }
+
+        JsonElement getSystemIdElement() {
+            return systemIdElem;
+        }
+
         JsonObject getPayloadElement() {
             return payloadElem;
         }
--- a/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/receiver/ReceiverRegistry.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/agent/receiver/ReceiverRegistry.java	Fri Jul 28 12:16:49 2017 +0200
@@ -41,7 +41,7 @@
 import com.redhat.thermostat.common.utils.ServiceRegistry;
 
 /**
- * Handles registering {@link RequestReceiver}s into the framework.
+ * Handles retrieving {@link RequestReceiver}s from the framework.
  */
 public class ReceiverRegistry {
 
@@ -51,15 +51,15 @@
         proxy = new ServiceRegistry<RequestReceiver>(context, RequestReceiver.class.getName());
     }
 
-    public void registerReceiver(RequestReceiver receiver) {
-        proxy.registerService(receiver, receiver.getClass().getName());
+    void registerReceiver(RequestReceiver receiver, String actionName) {
+        proxy.registerService(receiver, actionName);
     }
 
-    public RequestReceiver getReceiver(String clazz) {
-        return proxy.getService(clazz);
+    public RequestReceiver getReceiver(String actionName) {
+        return proxy.getService(actionName);
     }
 
-    public void unregisterReceivers() {
+    void unregisterReceivers() {
         proxy.unregisterAll();
     }
 }
--- a/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/model/AgentRequest.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/model/AgentRequest.java	Fri Jul 28 12:16:49 2017 +0200
@@ -36,21 +36,49 @@
 
 package com.redhat.thermostat.commands.model;
 
+import java.util.Objects;
 import java.util.SortedMap;
 
 /**
- * A Command Channel Request relayed to an agent (a.k.a receiver).
+ * A Command Channel Request relayed to an agent (a.k.a receiver). An agent
+ * message contains additional information which might be security relevant.
+ *
+ * In particular, {@code action}, {@code jvmId} and {@code systemId} are
+ * strings which have been looked at by the authorization layer and can, thus,
+ * be trusted. If there was an authorization problem, no agent request would
+ * have been created.
  */
 public class AgentRequest extends WebSocketRequest implements Message {
-	
-	public static final String RECEIVER_PARAM_NAME = "receiver";
+
+    private final String action;
+    private final String jvmId;
+    private final String systemId;
 
-    public AgentRequest(long sequence, SortedMap<String, String> params) {
+    public AgentRequest(long sequence,
+                        String action,
+                        String systemId,
+                        String jvmId,
+                        SortedMap<String, String> params) {
         super(sequence, params);
+        this.action = Objects.requireNonNull(action);
+        this.jvmId = Objects.requireNonNull(jvmId);
+        this.systemId = Objects.requireNonNull(systemId);
     }
 
     @Override
     public MessageType getMessageType() {
         return MessageType.AGENT_REQUEST;
     }
+
+    public String getAction() {
+        return action;
+    }
+
+    public String getJvmId() {
+        return jvmId;
+    }
+
+    public String getSystemId() {
+        return systemId;
+    }
 }
--- a/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/model/Message.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/commands/agent/src/main/java/com/redhat/thermostat/commands/model/Message.java	Fri Jul 28 12:16:49 2017 +0200
@@ -45,6 +45,9 @@
     public static final String PAYLOAD_KEY = "payload";
     public static final String TYPE_KEY = "type";
     public static final String SEQUENCE_KEY = "sequence";
+    public static final String ACTION_KEY = "action";
+    public static final String SYSTEM_ID_KEY = "systemId";
+    public static final String JVM_ID_KEY = "jvmId";
 
     public MessageType getMessageType();
 
--- a/plugins/commands/agent/src/test/java/com/redhat/thermostat/commands/agent/internal/receiver/PingReceiverTest.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/commands/agent/src/test/java/com/redhat/thermostat/commands/agent/internal/receiver/PingReceiverTest.java	Fri Jul 28 12:16:49 2017 +0200
@@ -51,9 +51,19 @@
     @Test
     public void canReceiveWithProperSequence() {
         PingReceiver receiver = new PingReceiver();
-        AgentRequest req = new AgentRequest(123L, new TreeMap<String, String>());
+        AgentRequest req = new AgentRequest(123L, "ping", "system_id", "jvm_id", new TreeMap<String, String>());
         WebSocketResponse resp = receiver.receive(req);
         assertEquals(ResponseType.OK, resp.getResponseType());
         assertEquals(123L, resp.getSequenceId());
     }
+    
+    @Test
+    public void badActionReturnsError() {
+        String badAction = "not_ping";
+        PingReceiver receiver = new PingReceiver();
+        AgentRequest req = new AgentRequest(199L, badAction, "system_id", "jvm_id", new TreeMap<String, String>());
+        WebSocketResponse resp = receiver.receive(req);
+        assertEquals(ResponseType.ERROR, resp.getResponseType());
+        assertEquals(199L, resp.getSequenceId());
+    }
 }
--- a/plugins/commands/agent/src/test/java/com/redhat/thermostat/commands/agent/internal/socket/AgentSocketOnMessageCallbackTest.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/commands/agent/src/test/java/com/redhat/thermostat/commands/agent/internal/socket/AgentSocketOnMessageCallbackTest.java	Fri Jul 28 12:16:49 2017 +0200
@@ -40,6 +40,7 @@
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.mockito.Matchers.eq;
 
 import java.io.IOException;
 import java.util.SortedMap;
@@ -83,37 +84,63 @@
      */
     @Test
     public void handlesAgentRequestsProperly() throws InterruptedException {
-        CountDownLatch latch = new CountDownLatch(1);
-        TestReceiverRegistry reg = new TestReceiverRegistry(latch);
+        final CountDownLatch latch = new CountDownLatch(1);
+        final String jvmId = "jvm_id";
+        final String systemId = "system_id";
+        final long sequenceId = 333L;
+        final String actionName = "foo-action";
+        ReceiverRegistry reg = mock(ReceiverRegistry.class);
+        RequestReceiver receiver = new PingReceiver() {
+            
+            @Override
+            public WebSocketResponse receive(AgentRequest request) {
+                WebSocketResponse resp = super.receive(request);
+                assertEquals(jvmId, request.getJvmId());
+                assertEquals(systemId, request.getSystemId());
+                assertEquals(sequenceId, request.getSequenceId());
+                latch.countDown();
+                return resp;
+            }
+            
+        };
+        when(reg.getReceiver(eq(actionName))).thenReturn(receiver);
         AgentSocketOnMessageCallback cb = new AgentSocketOnMessageCallback(reg);
         Session session = mock(Session.class);
         when(session.getRemote()).thenReturn(mock(RemoteEndpoint.class)); // Prevent spurious NPEs
         SortedMap<String, String> params = new TreeMap<>();
-        String receiverName = "foo-bar";
-        params.put("receiver", receiverName);
-        AgentRequest agentRequest = new AgentRequest(333L, params);
+        AgentRequest agentRequest = new AgentRequest(333L, actionName, systemId, jvmId, params);
         
         // Main method under test
         cb.run(session, agentRequest, gson);
         
-        // wait for request to be handled
+        // wait for request to be handled. Assertions are done in the
+        // receiver.
         latch.await();
-        assertEquals(receiverName, reg.clazz);
-        assertEquals(333L, reg.response.getSequenceId());
     }
     
     @Test
     public void handlerThreadSendsResponseToSession() throws InterruptedException, IOException {
         ArgumentCaptor<String> jsonCaptor = ArgumentCaptor.forClass(String.class);
-        CountDownLatch receiverHandled = new CountDownLatch(1);
-        CountDownLatch sentLatch = new CountDownLatch(1);
-        TestReceiverRegistry reg = new TestReceiverRegistry(receiverHandled);
+        final CountDownLatch receiverHandled = new CountDownLatch(1);
+        final CountDownLatch sentLatch = new CountDownLatch(1);
+        final String actionName = "ping"; // must be ping otherwise receiver will return ERROR
+        ReceiverRegistry reg = mock(ReceiverRegistry.class);
+        RequestReceiver receiver = new PingReceiver() {
+            
+            @Override
+            public WebSocketResponse receive(AgentRequest request) {
+                WebSocketResponse resp = super.receive(request);
+                receiverHandled.countDown();
+                return resp;
+            }
+            
+        };
+        when(reg.getReceiver(eq(actionName))).thenReturn(receiver);
         Session session = mock(Session.class);
         RemoteEndpoint mockEndpoint = mock(RemoteEndpoint.class);
         when(session.getRemote()).thenReturn(mockEndpoint);
         SortedMap<String, String> params = new TreeMap<>();
-        params.put("receiver", "ignored");
-        AgentRequest agentRequest = new AgentRequest(344L, params);
+        AgentRequest agentRequest = new AgentRequest(344L, actionName, "system_id", "jvm_id", params);
         
         CmdChannelRequestHandler handler = new CmdChannelRequestHandler(session, agentRequest, reg, gson, sentLatch);
         handler.start(); // start asynchronously
@@ -141,7 +168,7 @@
         RemoteEndpoint mockEndpoint = mock(RemoteEndpoint.class);
         when(session.getRemote()).thenReturn(mockEndpoint);
         SortedMap<String, String> emptyParams = new TreeMap<>();
-        AgentRequest agentRequest = new AgentRequest(888L, emptyParams);
+        AgentRequest agentRequest = new AgentRequest(888L, "not-exist", "system_id", "jvm_id", emptyParams);
         
         CmdChannelRequestHandler handler = new CmdChannelRequestHandler(session, agentRequest, mock(ReceiverRegistry.class), gson, sentLatch);
         handler.start(); // start asynchronously
@@ -174,30 +201,4 @@
         AgentSocketOnMessageCallback cb = new AgentSocketOnMessageCallback(mock(ReceiverRegistry.class));
         cb.run(null, request, gson); // throws exception
     }
-
-    static class TestReceiverRegistry extends ReceiverRegistry {
-
-        private final CountDownLatch latch;
-        private String clazz;
-        private WebSocketResponse response;
-
-        TestReceiverRegistry(CountDownLatch latch) {
-            super(null);
-            this.latch = latch;
-        }
-
-        @Override
-        public RequestReceiver getReceiver(String clazz) {
-            this.clazz = clazz;
-            return new PingReceiver() {
-                @Override
-                public WebSocketResponse receive(AgentRequest request) {
-                    WebSocketResponse resp = super.receive(request);
-                    response = resp;
-                    latch.countDown();
-                    return resp;
-                }
-            };
-        }
-    }
 }
--- a/plugins/commands/agent/src/test/java/com/redhat/thermostat/commands/agent/receiver/ReceiverRegistryTest.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/commands/agent/src/test/java/com/redhat/thermostat/commands/agent/receiver/ReceiverRegistryTest.java	Fri Jul 28 12:16:49 2017 +0200
@@ -58,16 +58,17 @@
         StubBundleContext context = new StubBundleContext();
         ReceiverRegistry reg = new ReceiverRegistry(context);
         RequestReceiver receiver = mock(RequestReceiver.class);
+        String actionName = "foo-action";
         
-        reg.registerReceiver(receiver);
+        reg.registerReceiver(receiver, actionName);
         
         context.isServiceRegistered(RequestReceiver.class.getName(), receiver.getClass());
-        Collection<?> services = context.getServiceReferences(RequestReceiver.class, String.format(FILTER_FORMAT, receiver.getClass().getName()));
+        Collection<?> services = context.getServiceReferences(RequestReceiver.class, String.format(FILTER_FORMAT, actionName));
         assertEquals(1, services.size());
         @SuppressWarnings("unchecked")
         ServiceReference<RequestReceiver> sr = (ServiceReference<RequestReceiver>)services.iterator().next();
         String serviceName = (String)sr.getProperty("servicename");
-        assertEquals(receiver.getClass().getName(), serviceName);
+        assertEquals(actionName, serviceName);
     }
     
     @Test
@@ -75,10 +76,11 @@
         StubBundleContext context = new StubBundleContext();
         ReceiverRegistry reg = new ReceiverRegistry(context);
         assertNull(reg.getReceiver(String.class.getName()));
+        String actionName = "kill-vm";
         
         RequestReceiver receiver = mock(RequestReceiver.class);
-        reg.registerReceiver(receiver);
-        RequestReceiver actual = reg.getReceiver(receiver.getClass().getName());
+        reg.registerReceiver(receiver, actionName);
+        RequestReceiver actual = reg.getReceiver(actionName);
         assertSame(receiver, actual);
     }
 }
--- a/plugins/killvm/agent/src/main/java/com/redhat/thermostat/killvm/agent/internal/KillVmReceiver.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/killvm/agent/src/main/java/com/redhat/thermostat/killvm/agent/internal/KillVmReceiver.java	Fri Jul 28 12:16:49 2017 +0200
@@ -55,9 +55,10 @@
 
 @Component
 @Service(value = RequestReceiver.class)
-@Property(name = "servicename", value = "com.redhat.thermostat.killvm.agent.internal.KillVmReceiver")
+@Property(name = "servicename", value = KillVmReceiver.ACTION_NAME)
 public class KillVmReceiver implements RequestReceiver {
 
+    public static final String ACTION_NAME = "kill_vm";
     private static final Logger log = LoggingUtils.getLogger(KillVmReceiver.class);
     
     @Reference
@@ -70,6 +71,11 @@
     
     @Override
     public WebSocketResponse receive(AgentRequest request) {
+        // Sanity check. We should never get requests outside our action domain.
+        if (!ACTION_NAME.equals(request.getAction())) {
+            log.severe("Received action '" + request.getAction() + "' for receiver '" + ACTION_NAME + "'");
+            return new WebSocketResponse(request.getSequenceId(), ResponseType.ERROR);
+        }
         if (processService == null) {
             // no dice, should have service by now
             log.severe("Process service is null!");
--- a/plugins/killvm/agent/src/test/java/com/redhat/thermostat/killvm/agent/internal/KillVmReceiverTest.java	Mon Jul 31 12:43:10 2017 +0200
+++ b/plugins/killvm/agent/src/test/java/com/redhat/thermostat/killvm/agent/internal/KillVmReceiverTest.java	Fri Jul 28 12:16:49 2017 +0200
@@ -37,10 +37,8 @@
 package com.redhat.thermostat.killvm.agent.internal;
 
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
 import static org.mockito.Mockito.mock;
 
-import java.lang.reflect.Method;
 import java.util.SortedMap;
 import java.util.TreeMap;
 
@@ -60,7 +58,7 @@
         receiver.bindProcessService(proc);
         SortedMap<String, String> params = new TreeMap<>();
         params.put("vm-pid", "12345");
-        AgentRequest req = new AgentRequest(322, params);
+        AgentRequest req = new AgentRequest(322, KillVmReceiver.ACTION_NAME, "some_systemId", "some_jvmId", params);
         WebSocketResponse response = receiver.receive(req);
         assertEquals(ResponseType.OK, response.getResponseType());
         assertEquals(322, response.getSequenceId());
@@ -72,7 +70,7 @@
         KillVmReceiver receiver = new KillVmReceiver();
         receiver.bindProcessService(proc);
         SortedMap<String, String> params = new TreeMap<>();
-        AgentRequest req = new AgentRequest(-1, params);
+        AgentRequest req = new AgentRequest(-1, KillVmReceiver.ACTION_NAME, "some_systemId", "some_jvmId", params);
         WebSocketResponse response = receiver.receive(req);
         assertEquals(ResponseType.ERROR, response.getResponseType());
         assertEquals(-1, response.getSequenceId());
@@ -85,7 +83,7 @@
         receiver.bindProcessService(proc);
         SortedMap<String, String> params = new TreeMap<>();
         params.put("vm-pid", "hi");
-        AgentRequest req = new AgentRequest(211, params);
+        AgentRequest req = new AgentRequest(211, KillVmReceiver.ACTION_NAME, "some_systemId", "some_jvmId", params);
         WebSocketResponse response = receiver.receive(req);
         assertEquals(ResponseType.ERROR, response.getResponseType());
         assertEquals(211, response.getSequenceId());
@@ -95,44 +93,25 @@
     public void receiverReturnsErrorNoProcessHandler() {
         KillVmReceiver receiver = new KillVmReceiver();
         SortedMap<String, String> params = new TreeMap<>();
-        AgentRequest req = new AgentRequest(11, params);
+        AgentRequest req = new AgentRequest(11, KillVmReceiver.ACTION_NAME, "some_systemId", "some_jvmId", params);
         WebSocketResponse response = receiver.receive(req);
         assertEquals(ResponseType.ERROR, response.getResponseType());
         assertEquals(11, response.getSequenceId());
     }
 
     /**
-     * When a request is issued the fully qualified receiver class name is set
-     * via the 'receiver' param name. This test makes sure that this
-     * class is actually where it's supposed to be.
-     * 
-     * @throws Exception
+     * The expected action is kill_vm. If invoked via a different one, we
+     * expect an error response.
      */
     @Test
-    public void killVmReceiverIsInAppropriatePackage() {
-        Class<?> receiver = null;
-        try {
-            // com.redhat.thermostat.client.killvm.internal.KillVMAction uses
-            // this class name.
-            receiver = Class
-                    .forName("com.redhat.thermostat.killvm.agent.internal.KillVmReceiver");
-        } catch (ClassNotFoundException e) {
-            fail("com.redhat.thermostat.agent.killvm.internal.KillVmReceiver class not found, but used by some request!");
-        }
-        try {
-            ProcessHandler service = mock(ProcessHandler.class);
-            Object instance = receiver.newInstance();
-            Method bind = receiver.getDeclaredMethod("bindProcessService", ProcessHandler.class);
-            bind.invoke(instance, service);
-            Method m = receiver.getMethod("receive", AgentRequest.class);
-            SortedMap<String, String> params = new TreeMap<>();
-            params.put("vm-pid", "12345");
-            AgentRequest req = new AgentRequest(322, params);
-            m.invoke(instance, req);
-        } catch (Exception e) {
-            e.printStackTrace();
-            fail("cannot invoke receiver's receive method");
-        }
+    public void receiverReturnsErrorWhenBadAction() {
+        String unexpectedActionName = "foo-bar";
+        KillVmReceiver receiver = new KillVmReceiver();
+        SortedMap<String, String> params = new TreeMap<>();
+        AgentRequest req = new AgentRequest(13, unexpectedActionName, "some_systemId", "some_jvmId", params);
+        WebSocketResponse response = receiver.receive(req);
+        assertEquals(ResponseType.ERROR, response.getResponseType());
+        assertEquals(13, response.getSequenceId());
     }
 }