changeset 1608:de525fbb26cf

Set shared data structures as ServletContext attributes. Reviewed-by: vanaltj Review-thread: http://icedtea.classpath.org/pipermail/thermostat/2014-December/012086.html PR2059
author Severin Gehwolf <sgehwolf@redhat.com>
date Mon, 08 Dec 2014 14:56:05 +0100
parents 605f0c269f1c
children 6e5c85ab43dc
files web/server/src/main/java/com/redhat/thermostat/web/server/WebStorageEndPoint.java web/server/src/test/java/com/redhat/thermostat/web/server/WebStorageEndPointUnitTest.java
diffstat 2 files changed, 38 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/web/server/src/main/java/com/redhat/thermostat/web/server/WebStorageEndPoint.java	Mon Dec 01 10:43:12 2014 +0100
+++ b/web/server/src/main/java/com/redhat/thermostat/web/server/WebStorageEndPoint.java	Mon Dec 08 14:56:05 2014 +0100
@@ -131,10 +131,10 @@
     static final String FILES_WRITE_GRANT_ROLE_PREFIX = "thermostat-files-grant-write-filename-";
     private static final String TOKEN_MANAGER_TIMEOUT_PARAM = "token-manager-timeout";
     private static final String TOKEN_MANAGER_KEY = "token-manager";
-    private static final String USER_PRINCIPAL_CALLBACK_KEY = "user-principal-callback";
     private static final String CURSOR_MANAGER_KEY = "cursor-manager";
     static final String CATEGORY_MANAGER_KEY = "category-manager";
     static final String PREPARED_STMT_MANAGER_KEY = "prepared-stmt-manager";
+    static final String SERVER_TOKEN_KEY = "server-token";
     private static final int UNKNOWN_CURSOR_ID = -0xdeadbeef;
 
     // our strings can contain non-ASCII characters. Use UTF-8
@@ -152,12 +152,12 @@
     public static final String STORAGE_PASSWORD = "storage.password";
     public static final String STORAGE_CLASS = "storage.class";
     
-    private UUID serverToken;
-    
     // read-only set of all known statement descriptors we trust and allow
     private Set<String> knownStatementDescriptors;
     // read-only set of all known categories which we allow to get registered.
     private Set<String> knownCategoryNames;
+    // the principal callback used for retrieving the JAAS user principal
+    private PrincipalCallback principalCallback;
 
     @Override
     public void init(ServletConfig config) throws ServletException {
@@ -177,13 +177,6 @@
                 .registerTypeAdapterFactory(new WebPreparedStatementTypeAdapterFactory())
                 .registerTypeAdapterFactory(new PreparedParametersTypeAdapterFactory())
                 .create();
-        TokenManager tokenManager = new TokenManager();
-        String timeoutParam = getInitParameter(TOKEN_MANAGER_TIMEOUT_PARAM);
-        if (timeoutParam != null) {
-            tokenManager.setTimeout(Integer.parseInt(timeoutParam));
-        }
-        ServletContext servletContext = getServletContext();
-        servletContext.setAttribute(TOKEN_MANAGER_KEY, tokenManager);
         
         // Set the set of statement descriptors which we trust
         KnownDescriptorRegistry descRegistry = KnownDescriptorRegistryFactory.getInstance();
@@ -192,18 +185,27 @@
         KnownCategoryRegistry categoryRegistry = KnownCategoryRegistryFactory.getInstance();
         knownCategoryNames = categoryRegistry.getRegisteredCategoryNames();
         
-        // finally set callback for retrieving our JAAS user principal
+        ServletContext servletContext = getServletContext();
+        
         String serverInfo = servletContext.getServerInfo();
         ServletContainerInfoFactory factory = new ServletContainerInfoFactory(serverInfo);
         ServletContainerInfo info = factory.getInfo();
         PrincipalCallbackFactory cbFactory = new PrincipalCallbackFactory(info);
-        PrincipalCallback callback = Objects.requireNonNull(cbFactory.getCallback());
-        servletContext.setAttribute(USER_PRINCIPAL_CALLBACK_KEY, callback);
+        principalCallback = Objects.requireNonNull(cbFactory.getCallback());
+        
+        TokenManager tokenManager = new TokenManager();
+        String timeoutParam = getInitParameter(TOKEN_MANAGER_TIMEOUT_PARAM);
+        if (timeoutParam != null) {
+            tokenManager.setTimeout(Integer.parseInt(timeoutParam));
+        }
+        // The following get set as servlet context attributes in order
+        // to support clustered deployments.
         synchronized(servletContext) {
+            servletContext.setAttribute(TOKEN_MANAGER_KEY, tokenManager);
             servletContext.setAttribute(CATEGORY_MANAGER_KEY, new CategoryManager());
             servletContext.setAttribute(PREPARED_STMT_MANAGER_KEY, new PreparedStatementManager());
+            servletContext.setAttribute(SERVER_TOKEN_KEY, UUID.randomUUID());
         }
-        serverToken = UUID.randomUUID();
     }
     
     @Override
@@ -328,6 +330,7 @@
         // malicious client which sends a bad token on purpose. In either case
         // it should be OK to solely send back a distinct error code indicating
         // this situation.
+        final UUID serverToken = getServerToken();
         if (!serverToken.equals(catId.getServerToken())) {
             logger.log(Level.INFO, "Server token: '" + serverToken +
                     "' and client token '" + catId.getServerToken() +
@@ -558,7 +561,7 @@
                 category = gson.fromJson(categoryParam, Category.class);
                 storage.registerCategory(category);
             }
-            id = catManager.putCategory(serverToken, category, catIdentifier);
+            id = catManager.putCategory(getServerToken(), category, catIdentifier);
             if (isAggregateCat) {
                 logger.log(Level.FINEST, "(id: " + id.getId() + ") did not register aggregate category " + category );
             } else {
@@ -599,6 +602,7 @@
         // Check if the server token the client knows about still matches.
         // Bail out early otherwise.
         SharedStateId stmtId = stmt.getStatementId();
+        final UUID serverToken = getServerToken();
         if (!serverToken.equals(stmtId.getServerToken())) {
             logger.log(Level.INFO, "Server token: '" + serverToken +
                                    "' and client token '" + stmtId.getServerToken() +
@@ -680,6 +684,10 @@
         return Objects.requireNonNull(attributeVal);
     }
     
+    private UUID getServerToken() {
+        return getServletContextAttribute(SERVER_TOKEN_KEY);
+    }
+    
     private CategoryManager getCategoryManager() {
         return getServletContextAttribute(CATEGORY_MANAGER_KEY);
     }
@@ -817,6 +825,7 @@
         // Check if the server token the client knows about still matches.
         // Bail out early otherwise.
         SharedStateId stmtId = stmt.getStatementId();
+        final UUID serverToken = getServerToken();
         if (!serverToken.equals(stmtId.getServerToken())) {
             logger.log(Level.INFO, "Server token: '" + serverToken +
                                    "' and client token '" + stmtId.getServerToken() +
@@ -848,10 +857,7 @@
     
     private UserPrincipal getUserPrincipal(HttpServletRequest req) {
         Principal principal = req.getUserPrincipal();
-        
-        ServletContext context = getServletContext();
-        PrincipalCallback callback = (PrincipalCallback)context.getAttribute(USER_PRINCIPAL_CALLBACK_KEY);
-        return callback.getUserPrincipal(principal);
+        return principalCallback.getUserPrincipal(principal);
     }
 
     /*
--- a/web/server/src/test/java/com/redhat/thermostat/web/server/WebStorageEndPointUnitTest.java	Mon Dec 01 10:43:12 2014 +0100
+++ b/web/server/src/test/java/com/redhat/thermostat/web/server/WebStorageEndPointUnitTest.java	Mon Dec 08 14:56:05 2014 +0100
@@ -54,6 +54,7 @@
 import java.nio.file.attribute.FileAttribute;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.UUID;
 
 import javax.servlet.ServletConfig;
 import javax.servlet.ServletContext;
@@ -227,9 +228,18 @@
         ThCreatorResult result = creatWorkingThermostatHome();
         System.setProperty(TH_HOME_PROP_NAME, result.thermostatHome.toFile().getAbsolutePath());
         endpoint.init(config);
-        ArgumentCaptor<CategoryManager> managerCaptor = ArgumentCaptor.forClass(CategoryManager.class);
-        verify(mockContext).setAttribute(eq("category-manager"), managerCaptor.capture());
-        assertNotNull(managerCaptor.getValue());
+        ArgumentCaptor<CategoryManager> categoryManagerCaptor = ArgumentCaptor.forClass(CategoryManager.class);
+        ArgumentCaptor<PreparedStatementManager> prepStmtManagerCaptor = ArgumentCaptor.forClass(PreparedStatementManager.class);
+        ArgumentCaptor<TokenManager> tokenManagerCaptor = ArgumentCaptor.forClass(TokenManager.class);
+        ArgumentCaptor<UUID> serverTokenCaptor = ArgumentCaptor.forClass(UUID.class);
+        verify(mockContext).setAttribute(eq("category-manager"), categoryManagerCaptor.capture());
+        verify(mockContext).setAttribute(eq("prepared-stmt-manager"), prepStmtManagerCaptor.capture());
+        verify(mockContext).setAttribute(eq("token-manager"), tokenManagerCaptor.capture());
+        verify(mockContext).setAttribute(eq("server-token"), serverTokenCaptor.capture());
+        assertNotNull(categoryManagerCaptor.getValue());
+        assertNotNull(prepStmtManagerCaptor.getValue());
+        assertNotNull(tokenManagerCaptor.getValue());
+        assertNotNull(serverTokenCaptor.getValue());
     }
     
     private ThCreatorResult creatWorkingThermostatHome() throws IOException {