changeset 1288:d09fd25379e8

Fix prepared statement patching. Reviewed-by: omajid Review-thread: http://icedtea.classpath.org/pipermail/thermostat/2013-October/008540.html
author Severin Gehwolf <sgehwolf@redhat.com>
date Tue, 22 Oct 2013 17:48:37 +0200
parents 4b1b5692a596
children 7a0ff324a833
files storage/core/src/main/java/com/redhat/thermostat/storage/core/Statement.java storage/core/src/main/java/com/redhat/thermostat/storage/internal/statement/ParsedStatementImpl.java storage/core/src/test/java/com/redhat/thermostat/storage/internal/statement/ParsedStatementImplTest.java storage/core/src/test/java/com/redhat/thermostat/storage/internal/statement/PreparedStatementImplTest.java storage/mongo/src/main/java/com/redhat/thermostat/storage/mongodb/internal/MongoQuery.java storage/mongo/src/main/java/com/redhat/thermostat/storage/mongodb/internal/MongoStorage.java web/server/src/main/java/com/redhat/thermostat/web/server/WebStorageEndPoint.java
diffstat 7 files changed, 168 insertions(+), 33 deletions(-) [+]
line wrap: on
line diff
--- a/storage/core/src/main/java/com/redhat/thermostat/storage/core/Statement.java	Tue Oct 22 19:08:05 2013 +0200
+++ b/storage/core/src/main/java/com/redhat/thermostat/storage/core/Statement.java	Tue Oct 22 17:48:37 2013 +0200
@@ -3,10 +3,23 @@
 import com.redhat.thermostat.storage.model.Pojo;
 
 /**
- * Marker interface for all operations on storage. This includes queries and
- * statements manipulating data.
- *
+ * Implementations of this interface represent operations on storage. This
+ * includes queries and statements manipulating data.
+ * 
+ * @see BackingStorage
+ * @see Query
+ * @see Update
+ * @see Replace
+ * @see Add
+ * @see Remove
  */
 public interface Statement<T extends Pojo> {
 
+    /**
+     * Produces a copy of this statement as if it was just created with the
+     * corresponding factory method in {@link BackingStorage}.
+     * 
+     * @return A new raw instance of this statement.
+     */
+    Statement<T> getRawDuplicate();
 }
--- a/storage/core/src/main/java/com/redhat/thermostat/storage/internal/statement/ParsedStatementImpl.java	Tue Oct 22 19:08:05 2013 +0200
+++ b/storage/core/src/main/java/com/redhat/thermostat/storage/internal/statement/ParsedStatementImpl.java	Tue Oct 22 17:48:37 2013 +0200
@@ -78,16 +78,23 @@
             IllegalStateException expn = new IllegalStateException(msg);
             throw new IllegalPatchException(expn);
         }
-        patchSetList(params);
-        patchWhere(params);
-        patchSort(params);
-        patchLimit(params);
+        
+        /*
+         * Statements may not be stateless and hence, we need to create a
+         * duplicate prior every patch + execution.
+         */
+        Statement<T> stmt = statement.getRawDuplicate();
+        
+        patchSetList(stmt, params);
+        patchWhere(stmt, params);
+        patchSort(stmt, params);
+        patchLimit(stmt, params);
         // TODO count actual patches and throw an exception if not all vars
         // have been patched up.
-        return statement;
+        return stmt;
     }
 
-    private void patchSetList(PreparedParameter[] params) throws IllegalPatchException {
+    private void patchSetList(Statement<T> stmt, PreparedParameter[] params) throws IllegalPatchException {
         if (setList.getValues().size() == 0) {
             // no set list, nothing to do
             return;
@@ -95,66 +102,66 @@
         // do the patching
         PatchedSetList patchedSetList = setList.patch(params);
         // set the values
-        if (statement instanceof Add) {
-            Add<T> add = (Add<T>)statement;
+        if (stmt instanceof Add) {
+            Add<T> add = (Add<T>)stmt;
             for (PatchedSetListMember member: patchedSetList.getSetListMembers()) {
                 add.set(member.getKey().getName(), member.getValue());
             }
         }
-        if (statement instanceof Replace) {
-            Replace<T> replace = (Replace<T>)statement;
+        if (stmt instanceof Replace) {
+            Replace<T> replace = (Replace<T>)stmt;
             for (PatchedSetListMember member: patchedSetList.getSetListMembers()) {
                 replace.set(member.getKey().getName(), member.getValue());
             }
         }
-        if (statement instanceof Update) {
-            Update<T> update = (Update<T>)statement;
+        if (stmt instanceof Update) {
+            Update<T> update = (Update<T>)stmt;
             for (PatchedSetListMember member: patchedSetList.getSetListMembers()) {
                 update.set(member.getKey().getName(), member.getValue());
             }
         }
     }
 
-    private void patchLimit(PreparedParameter[] params) throws IllegalPatchException {
+    private void patchLimit(Statement<T> stmt, PreparedParameter[] params) throws IllegalPatchException {
         LimitExpression expn = suffixExpn.getLimitExpn();
         if (expn == null) {
             // no limit expn, nothing to do
             return;
         }
         PatchedLimitExpression patchedExp = expn.patch(params);
-        if (statement instanceof Query) {
-            Query<T> query = (Query<T>) statement;
+        if (stmt instanceof Query) {
+            Query<T> query = (Query<T>) stmt;
             query.limit(patchedExp.getLimitValue());
         } else {
             String msg = "Patching 'limit' of non-query types not supported! Class was:"
-                    + statement.getClass().getName();
+                    + stmt.getClass().getName();
             IllegalStateException invalid = new IllegalStateException(msg);
             throw new IllegalPatchException(invalid);
         }
     }
 
-    private void patchSort(PreparedParameter[] params) throws IllegalPatchException {
+    private void patchSort(Statement<T> stmt, PreparedParameter[] params) throws IllegalPatchException {
         SortExpression expn = suffixExpn.getSortExpn();
         if (expn == null) {
             // no sort expn, nothing to do
             return;
         }
         PatchedSortExpression patchedExp = expn.patch(params);
-        if (statement instanceof Query) {
-            Query<T> query = (Query<T>) statement;
+        if (stmt instanceof Query) {
+            Query<T> query = (Query<T>) stmt;
             PatchedSortMember[] members = patchedExp.getSortMembers();
             for (int i = 0; i < members.length; i++) {
                 query.sort(members[i].getSortKey(), members[i].getDirection());
             }
         } else {
             String msg = "Patching 'sort' of non-query types not supported! Class was:"
-                    + statement.getClass().getName();
+                    + stmt.getClass().getName();
             IllegalStateException invalid = new IllegalStateException(msg);
             throw new IllegalPatchException(invalid);
         }
     }
 
-    private void patchWhere(PreparedParameter[] params) throws IllegalPatchException {
+    private void patchWhere(Statement<T> stmt, PreparedParameter[] params) throws IllegalPatchException {
         WhereExpression expn = suffixExpn.getWhereExpn();
         if (expn == null) {
             // no where, nothing to do
@@ -164,21 +171,21 @@
         // the way.
         PatchedWhereExpression patchedExp = expn.patch(params);
         Expression whereClause = patchedExp.getExpression();
-        if (statement instanceof Query) {
-            Query<T> query = (Query<T>) statement;
+        if (stmt instanceof Query) {
+            Query<T> query = (Query<T>) stmt;
             query.where(whereClause);
-        } else if (statement instanceof Replace) {
-            Replace<T> replace = (Replace<T>) statement;
+        } else if (stmt instanceof Replace) {
+            Replace<T> replace = (Replace<T>) stmt;
             replace.where(whereClause);
-        } else if (statement instanceof Update) {
-            Update<T> update = (Update<T>) statement;
+        } else if (stmt instanceof Update) {
+            Update<T> update = (Update<T>) stmt;
             update.where(whereClause);
-        } else if (statement instanceof Remove) {
-            Remove<T> remove = (Remove<T>) statement;
+        } else if (stmt instanceof Remove) {
+            Remove<T> remove = (Remove<T>) stmt;
             remove.where(whereClause);
         } else {
             String msg = "Patching of where clause not supported! Class was:"
-                    + statement.getClass().getName();
+                    + stmt.getClass().getName();
             IllegalStateException invalid = new IllegalStateException(msg);
             throw new IllegalPatchException(invalid);
         }
--- a/storage/core/src/test/java/com/redhat/thermostat/storage/internal/statement/ParsedStatementImplTest.java	Tue Oct 22 19:08:05 2013 +0200
+++ b/storage/core/src/test/java/com/redhat/thermostat/storage/internal/statement/ParsedStatementImplTest.java	Tue Oct 22 17:48:37 2013 +0200
@@ -38,8 +38,12 @@
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNotSame;
+import static org.junit.Assert.assertSame;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -60,6 +64,7 @@
 import com.redhat.thermostat.storage.core.Query;
 import com.redhat.thermostat.storage.core.Query.SortDirection;
 import com.redhat.thermostat.storage.core.Replace;
+import com.redhat.thermostat.storage.core.Statement;
 import com.redhat.thermostat.storage.core.Update;
 import com.redhat.thermostat.storage.model.Pojo;
 import com.redhat.thermostat.storage.query.BinaryComparisonExpression;
@@ -85,6 +90,25 @@
     }
     
     @Test
+    public void patchingDuplicatesStatement() throws IllegalPatchException {
+        @SuppressWarnings("unchecked")
+        Statement<Pojo> stmt = (Statement<Pojo>)mock(Statement.class);
+        @SuppressWarnings("unchecked")
+        Statement<Pojo> mock2 = mock(Statement.class);
+        when(stmt.getRawDuplicate()).thenReturn(mock2);
+        ParsedStatementImpl<Pojo> parsedStmt = new ParsedStatementImpl<>(stmt);
+        SuffixExpression suffixExpn = new SuffixExpression();
+        suffixExpn.setLimitExpn(null);
+        suffixExpn.setSortExpn(null);
+        parsedStmt.setSetList(new SetList());
+        parsedStmt.setSuffixExpression(suffixExpn);
+        
+        Statement<Pojo> result = parsedStmt.patchStatement(new PreparedParameter[] {});
+        assertNotSame("Statement should get duplicated on patching", stmt, result);
+        assertSame(mock2, result);
+    }
+    
+    @Test
     public void canPatchWhereAndExpr() throws IllegalPatchException {
         // create the parsedStatementImpl we are going to use
         ParsedStatementImpl<Pojo> parsedStmt = new ParsedStatementImpl<>(statement);
@@ -817,6 +841,11 @@
             // Not implemented
             throw new AssertionError();
         }
+
+        @Override
+        public Statement<Pojo> getRawDuplicate() {
+            return new TestQuery();
+        }
         
     }
     
@@ -834,6 +863,11 @@
             // not implemented
             throw new AssertionError();
         }
+
+        @Override
+        public Statement<T> getRawDuplicate() {
+            return new TestAdd<>();
+        }
         
     }
     
@@ -857,6 +891,11 @@
             // not implemented
             throw new AssertionError();
         }
+
+        @Override
+        public Statement<TestPojo> getRawDuplicate() {
+            return new TestReplace();
+        }
         
     }
     
@@ -882,6 +921,11 @@
             // not implemented
             throw new AssertionError();
         }
+
+        @Override
+        public Statement<TestPojo> getRawDuplicate() {
+            return new TestUpdate();
+        }
         
     }
     
--- a/storage/core/src/test/java/com/redhat/thermostat/storage/internal/statement/PreparedStatementImplTest.java	Tue Oct 22 19:08:05 2013 +0200
+++ b/storage/core/src/test/java/com/redhat/thermostat/storage/internal/statement/PreparedStatementImplTest.java	Tue Oct 22 17:48:37 2013 +0200
@@ -62,6 +62,7 @@
 import com.redhat.thermostat.storage.core.Query;
 import com.redhat.thermostat.storage.core.Remove;
 import com.redhat.thermostat.storage.core.Replace;
+import com.redhat.thermostat.storage.core.Statement;
 import com.redhat.thermostat.storage.core.StatementDescriptor;
 import com.redhat.thermostat.storage.core.StatementExecutionException;
 import com.redhat.thermostat.storage.core.Update;
@@ -351,6 +352,12 @@
             return 0;
         }
 
+        @Override
+        public Statement<T> getRawDuplicate() {
+            // we don't duplicate for this test
+            return this;
+        }
+
     }
     
     private static class TestReplace implements Replace<FooPojo> {
@@ -374,6 +381,12 @@
             this.executed = true;
             return 0;
         }
+
+        @Override
+        public Statement<FooPojo> getRawDuplicate() {
+            // we don't duplicate for this test
+            return this;
+        }
         
     }
     
@@ -399,6 +412,12 @@
             this.executed = true;
             return 0;
         }
+
+        @Override
+        public Statement<FooPojo> getRawDuplicate() {
+            // we don't duplicate for this test
+            return this;
+        }
         
     }
     
@@ -417,6 +436,12 @@
             this.executed = true;
             return 0;
         }
+
+        @Override
+        public Statement<FooPojo> getRawDuplicate() {
+            // we don't duplicate for this test
+            return this;
+        }
         
     }
     
@@ -478,6 +503,12 @@
             // not implemented
             return null;
         }
+
+        @Override
+        public Statement<Pojo> getRawDuplicate() {
+            // For this test, we don't duplicate
+            return this;
+        }
         
     }
 }
--- a/storage/mongo/src/main/java/com/redhat/thermostat/storage/mongodb/internal/MongoQuery.java	Tue Oct 22 19:08:05 2013 +0200
+++ b/storage/mongo/src/main/java/com/redhat/thermostat/storage/mongodb/internal/MongoQuery.java	Tue Oct 22 17:48:37 2013 +0200
@@ -43,6 +43,7 @@
 import com.redhat.thermostat.storage.core.AbstractQuery;
 import com.redhat.thermostat.storage.core.Category;
 import com.redhat.thermostat.storage.core.Cursor;
+import com.redhat.thermostat.storage.core.Statement;
 import com.redhat.thermostat.storage.model.Pojo;
 import com.redhat.thermostat.storage.query.Expression;
 
@@ -121,5 +122,10 @@
         return expression;
     }
 
+    @Override
+    public Statement<T> getRawDuplicate() {
+        return new MongoQuery<>(storage, category);
+    }
+
 }
 
--- a/storage/mongo/src/main/java/com/redhat/thermostat/storage/mongodb/internal/MongoStorage.java	Tue Oct 22 19:08:05 2013 +0200
+++ b/storage/mongo/src/main/java/com/redhat/thermostat/storage/mongodb/internal/MongoStorage.java	Tue Oct 22 17:48:37 2013 +0200
@@ -71,6 +71,7 @@
 import com.redhat.thermostat.storage.core.Query;
 import com.redhat.thermostat.storage.core.Remove;
 import com.redhat.thermostat.storage.core.Replace;
+import com.redhat.thermostat.storage.core.Statement;
 import com.redhat.thermostat.storage.core.StatementDescriptor;
 import com.redhat.thermostat.storage.core.Update;
 import com.redhat.thermostat.storage.model.AggregateCount;
@@ -98,6 +99,12 @@
         public Cursor<T> execute() {
             return executeGetCount(category, (MongoQuery<T>)this.queryToAggregate);
         }
+
+        @Override
+        public Statement<T> getRawDuplicate() {
+            MongoQuery<T> query = (MongoQuery<T>) this.queryToAggregate;
+            return new MongoCountQuery<>(query, category);
+        }
     }
     
     private static abstract class MongoSetter<T extends Pojo> {
@@ -153,6 +160,11 @@
         public void set(String key, Object value) {
             super.set(key, value);
         }
+
+        @Override
+        public Statement<T> getRawDuplicate() {
+            return new MongoAdd<>(category);
+        }
         
     }
 
@@ -186,6 +198,11 @@
         public void set(String key, Object value) {
             super.set(key, value);
         }
+
+        @Override
+        public Statement<T> getRawDuplicate() {
+            return new MongoReplace<>(category);
+        }
         
     }
     
@@ -217,6 +234,11 @@
             DBObject setValues = new BasicDBObject(SET_MODIFIER, values);
             return updateImpl(category, setValues, query);
         }
+
+        @Override
+        public Statement<T> getRawDuplicate() {
+            return new MongoUpdate<>(category);
+        }
     }
     
     private class MongoRemove<T extends Pojo> implements Remove<T> {
@@ -243,6 +265,11 @@
         public int apply() {
             return removePojo(category, query);
         }
+
+        @Override
+        public Statement<T> getRawDuplicate() {
+            return new MongoRemove<>(category);
+        }
         
     }
 
--- a/web/server/src/main/java/com/redhat/thermostat/web/server/WebStorageEndPoint.java	Tue Oct 22 19:08:05 2013 +0200
+++ b/web/server/src/main/java/com/redhat/thermostat/web/server/WebStorageEndPoint.java	Tue Oct 22 17:48:37 2013 +0200
@@ -86,6 +86,7 @@
 import com.redhat.thermostat.storage.core.PreparedParameters;
 import com.redhat.thermostat.storage.core.PreparedStatement;
 import com.redhat.thermostat.storage.core.Query;
+import com.redhat.thermostat.storage.core.Statement;
 import com.redhat.thermostat.storage.core.StatementDescriptor;
 import com.redhat.thermostat.storage.core.Storage;
 import com.redhat.thermostat.storage.core.auth.DescriptorMetadata;
@@ -827,6 +828,12 @@
                 // must not be called.
                 throw new IllegalStateException();
             }
+
+            @Override
+            public Statement<T> getRawDuplicate() {
+                // must not be called.
+                throw new IllegalStateException();
+            }
             
         };
         return empty;