View Javadoc
1   package ejava.util.jpa;
2   
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import javax.persistence.EntityManager;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
19  
20  /**
21   * This class will execute a set of drop and create DDL scripts against
22   * a supplied entity manager.
23   */
24  public class DBUtil {
25      static Logger logger = LoggerFactory.getLogger(DBUtil.class);
26      protected EntityManager em;
27      protected List<String> dropPaths=new LinkedList<>();
28      protected List<String> createPaths=new LinkedList<>();
29      
30      public DBUtil() {}
31      public DBUtil(EntityManager em, String dropPath, String createPath) {
32          setEntityManager(em);
33          addDropPath(dropPath);
34          addCreatePath(createPath);
35      }
36  
37      public void setEntityManager(EntityManager em) {
38          this.em = em; 
39      }
40      public void addDropPath(String dropPath) {
41          this.dropPaths.add(dropPath);
42      }
43      public void addCreatePath(String createPath) {
44          this.createPaths.add(createPath);
45      }
46      
47      /**
48       * Locate the input stream either as a file path or a resource path.
49       * @param path
50       * @return input stream
51       * @throws IllegalArgumentException if path not found 
52       */
53      protected InputStream getInputStream(String path) throws IllegalArgumentException {
54          InputStream is = null;
55          
56          File pathFile = new File(path);
57          //try file system first
58          if (pathFile.exists()) {
59              try {
60                  is = new FileInputStream(pathFile);
61              } catch (FileNotFoundException ex) {
62                  throw new IllegalArgumentException(String.format("unable to open file %s", path));                
63              }
64          }
65          //then try classpath
66          else {
67             if ((is = Thread.currentThread()
68                        .getContextClassLoader()
69                        .getResourceAsStream(path)) == null) {
70                 throw new IllegalArgumentException(String.format("%s not found in classpath", path));
71             }
72          }
73          
74          return is; //inputStream should not be null at this point
75      }
76      
77      /**
78       * Turn the InputStream into a String for easier parsing for SQL statements.
79       * @param is
80       * @return string contents of SQL file
81       * @throws IOException 
82       */
83      protected String getString(InputStream is) throws IOException {
84          StringBuilder text = new StringBuilder();
85          byte[] buffer = new byte[4096];
86          for (int n; (n = is.read(buffer)) != -1;) {
87              text.append(new String(buffer, 0,n));
88          }
89          return text.toString();
90      }
91          
92      /**
93       * Get list of SQL statements from the supplied string.
94       * @param contents
95       * @return list of distinct statements
96       * @throws Exception
97       */
98      protected List<String> getStatements(String contents) {
99          List<String> statements = new ArrayList<String>();
100 
101         for (String tok: contents.split(";")) {
102             statements.add(tok.trim());
103         }
104         return statements;
105     }
106     
107     /**
108      * Execute the SQL statements contained within the resource that is located
109      * by the path supplied. This path can be either a file path or resource 
110      * path.
111      * @param path
112      * @return count of statements executed
113      * @throws Exception
114      */
115     public int executeScript(String path) {
116         if (path == null || path.length() == 0) {
117             throw new IllegalStateException("no path provided");
118         }
119         
120         InputStream is = getInputStream(path);
121         if (is == null) {
122             throw new IllegalStateException("path not found:" + path);
123         }
124         
125         try {
126             String sql = getString(is);
127             List<String> statements = getStatements(sql);
128             logger.debug("found {} statements", statements.size());
129             
130             for (String statement : statements) {
131                 logger.debug("executing:" + statement);
132                 em.createNativeQuery(statement).executeUpdate();    
133             }
134 
135             return statements.size(); 
136         } catch (IOException ex) {
137             throw new IllegalStateException("error parsing SQL file", ex);
138         }
139     }
140 
141     /**
142      * Return the current nextVal for each sequence used by application.
143      * @return map of sequence nextVals
144      */
145     public Map<String, Number> getSequenceNextVals() {
146         //find the names of sequences used
147         @SuppressWarnings("unchecked")
148         List<String> sequenceNames = em.createNativeQuery(
149                 "SELECT sequence_name FROM INFORMATION_SCHEMA.SEQUENCES "
150                 + "where sequence_name not like 'SYSTEM%'")
151                 .getResultList();
152         
153         Map<String, Number> sequenceVals = new HashMap<>();
154         //query for the next value that would be reported
155         for (String sequenceName: sequenceNames) {
156             Number nextValue=(Number)em.createNativeQuery(
157                     String.format("call nextval('%s')", sequenceName))
158                     .getSingleResult();
159           sequenceVals.put(sequenceName, nextValue);            
160         }
161                 
162         return sequenceVals;
163     }
164     
165     /**
166      * Updates the state of the provided sequence(s) to values provided
167      * @param sequenceVals
168      */
169     public void setSequenceNextVals(Map<String, Number> sequenceVals) {
170         for (Entry<String, Number> sequence: sequenceVals.entrySet()) {
171             em.createNativeQuery(
172                     String.format("alter sequence %s restart with ?", sequence.getKey()))
173                 .setParameter(1, sequence.getValue()).executeUpdate();
174         }        
175     }
176     
177     /**
178      * Execute the drop script against the DB.
179      * @return map of sequence previous nextVal state
180      * @throws RuntimeException on errors like schema does not currently 
181      * exist. In that case the caller must rollback the current transaction
182      * and begin a new one for the create();
183      */
184     public Map<String, Number> dropAll() throws RuntimeException {
185         try {
186             Map<String, Number> sequenceVals = getSequenceNextVals();
187             for (String script: dropPaths) {
188                 executeScript(script);
189             }
190             return sequenceVals;
191         } catch (Exception ex) {
192             throw new RuntimeException("error dropping DB, might not exist: " + ex);
193         }
194     }
195 
196     /**
197      * Execute the create script against the DB. Follow-up by altering the provided
198      * sequences to be their previous nextVal state. This is necessary when resetting schema
199      * while JPA provider has cached a block of sequence values from a previous execution.
200      * @return count of statements executed
201      * @throws 
202      */
203     public int createAll(Map<String, Number> sequenceVals) throws RuntimeException {
204         int count=0;
205         for (String script: createPaths) {
206             count += executeScript(script);
207         }
208         setSequenceNextVals(sequenceVals);
209         return count;
210     }    
211 }