1 package ejava.util.jpa;
2
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import javax.persistence.EntityManager;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
19
20
21
22
23
24 public class DBUtil {
25 static Logger logger = LoggerFactory.getLogger(DBUtil.class);
26 protected EntityManager em;
27 protected List<String> dropPaths=new LinkedList<>();
28 protected List<String> createPaths=new LinkedList<>();
29
30 public DBUtil() {}
31 public DBUtil(EntityManager em, String dropPath, String createPath) {
32 setEntityManager(em);
33 addDropPath(dropPath);
34 addCreatePath(createPath);
35 }
36
37 public void setEntityManager(EntityManager em) {
38 this.em = em;
39 }
40 public void addDropPath(String dropPath) {
41 this.dropPaths.add(dropPath);
42 }
43 public void addCreatePath(String createPath) {
44 this.createPaths.add(createPath);
45 }
46
47
48
49
50
51
52
53 protected InputStream getInputStream(String path) throws IllegalArgumentException {
54 InputStream is = null;
55
56 File pathFile = new File(path);
57
58 if (pathFile.exists()) {
59 try {
60 is = new FileInputStream(pathFile);
61 } catch (FileNotFoundException ex) {
62 throw new IllegalArgumentException(String.format("unable to open file %s", path));
63 }
64 }
65
66 else {
67 if ((is = Thread.currentThread()
68 .getContextClassLoader()
69 .getResourceAsStream(path)) == null) {
70 throw new IllegalArgumentException(String.format("%s not found in classpath", path));
71 }
72 }
73
74 return is;
75 }
76
77
78
79
80
81
82
83 protected String getString(InputStream is) throws IOException {
84 StringBuilder text = new StringBuilder();
85 byte[] buffer = new byte[4096];
86 for (int n; (n = is.read(buffer)) != -1;) {
87 text.append(new String(buffer, 0,n));
88 }
89 return text.toString();
90 }
91
92
93
94
95
96
97
98 protected List<String> getStatements(String contents) {
99 List<String> statements = new ArrayList<String>();
100
101 for (String tok: contents.split(";")) {
102 statements.add(tok.trim());
103 }
104 return statements;
105 }
106
107
108
109
110
111
112
113
114
115 public int executeScript(String path) {
116 if (path == null || path.length() == 0) {
117 throw new IllegalStateException("no path provided");
118 }
119
120 InputStream is = getInputStream(path);
121 if (is == null) {
122 throw new IllegalStateException("path not found:" + path);
123 }
124
125 try {
126 String sql = getString(is);
127 List<String> statements = getStatements(sql);
128 logger.debug("found {} statements", statements.size());
129
130 for (String statement : statements) {
131 logger.debug("executing:" + statement);
132 em.createNativeQuery(statement).executeUpdate();
133 }
134
135 return statements.size();
136 } catch (IOException ex) {
137 throw new IllegalStateException("error parsing SQL file", ex);
138 }
139 }
140
141
142
143
144
145 public Map<String, Number> getSequenceNextVals() {
146
147 @SuppressWarnings("unchecked")
148 List<String> sequenceNames = em.createNativeQuery(
149 "SELECT sequence_name FROM INFORMATION_SCHEMA.SEQUENCES "
150 + "where sequence_name not like 'SYSTEM%'")
151 .getResultList();
152
153 Map<String, Number> sequenceVals = new HashMap<>();
154
155 for (String sequenceName: sequenceNames) {
156 Number nextValue=(Number)em.createNativeQuery(
157 String.format("call nextval('%s')", sequenceName))
158 .getSingleResult();
159 sequenceVals.put(sequenceName, nextValue);
160 }
161
162 return sequenceVals;
163 }
164
165
166
167
168
169 public void setSequenceNextVals(Map<String, Number> sequenceVals) {
170 for (Entry<String, Number> sequence: sequenceVals.entrySet()) {
171 em.createNativeQuery(
172 String.format("alter sequence %s restart with ?", sequence.getKey()))
173 .setParameter(1, sequence.getValue()).executeUpdate();
174 }
175 }
176
177
178
179
180
181
182
183
184 public Map<String, Number> dropAll() throws RuntimeException {
185 try {
186 Map<String, Number> sequenceVals = getSequenceNextVals();
187 for (String script: dropPaths) {
188 executeScript(script);
189 }
190 return sequenceVals;
191 } catch (Exception ex) {
192 throw new RuntimeException("error dropping DB, might not exist: " + ex);
193 }
194 }
195
196
197
198
199
200
201
202
203 public int createAll(Map<String, Number> sequenceVals) throws RuntimeException {
204 int count=0;
205 for (String script: createPaths) {
206 count += executeScript(script);
207 }
208 setSequenceNextVals(sequenceVals);
209 return count;
210 }
211 }