1 package ejava.examples.webtier.web;
2
3 import java.io.IOException;
4 import java.util.HashMap;
5
6 import javax.naming.Context;
7 import javax.naming.NameClassPair;
8 import javax.naming.NamingEnumeration;
9 import javax.naming.NamingException;
10 import javax.persistence.EntityManager;
11 import javax.persistence.EntityManagerFactory;
12 import javax.persistence.EntityTransaction;
13 import javax.persistence.Persistence;
14 import javax.persistence.PersistenceUnit;
15 import javax.servlet.Filter;
16 import javax.servlet.FilterChain;
17 import javax.servlet.FilterConfig;
18 import javax.servlet.ServletException;
19 import javax.servlet.ServletRequest;
20 import javax.servlet.ServletResponse;
21
22 import org.slf4j.Logger;
23 import org.slf4j.LoggerFactory;
24
25 public class JPAFilter implements Filter {
26 private static Logger logger = LoggerFactory.getLogger(JPAFilter.class);
27 private static final String PU_NAME = "webtier";
28 private boolean containerManaged=false;
29 @PersistenceUnit(unitName=PU_NAME)
30 private EntityManagerFactory emf;
31 private static ThreadLocal<EntityManager> em = new ThreadLocal<>();
32
33 public void init(FilterConfig config) throws ServletException {
34 logger.info("filter initializing JPA DAOs, em=f{}", emf);
35 System.out.println(String.format("filter initializing JPA DAOs, em=%s", emf));
36
37 if (emf==null) {
38 emf = Persistence.createEntityManagerFactory(PU_NAME);
39 } else {
40 containerManaged=true;
41 }
42 }
43
44 public void doFilter(ServletRequest request,
45 ServletResponse response,
46 FilterChain chain) throws IOException, ServletException {
47
48 logger.debug("injected entity manager={}", emf);
49 EntityManager entityMgr = initEntityManager();
50
51 EntityTransaction tx = entityMgr.getTransaction();
52 if (!tx.isActive()) {
53 logger.debug("filter: beginning JPA transaction");
54 tx.begin();
55 }
56
57 chain.doFilter(request, response);
58
59 if (tx.isActive()) {
60 if (tx.getRollbackOnly()==true) {
61 logger.debug("filter: rolling back JPA transaction");
62 tx.rollback();
63 }
64 else {
65 logger.debug("filter: committing JPA transaction");
66 tx.commit();
67 }
68 }
69 else {
70 logger.debug("filter: no transaction was active");
71 }
72
73 closeEntityManager();
74 }
75
76 public void destroy() {
77 if (!containerManaged) {
78 emf.close();
79 }
80 }
81
82 private EntityManager initEntityManager() throws ServletException {
83 EntityManager entityMgr = getEntityManager();
84 if (entityMgr==null) {
85 entityMgr = emf.createEntityManager();
86 em.set(entityMgr);
87 }
88 return entityMgr;
89 }
90
91 private void closeEntityManager() {
92 EntityManager entityMgr = getEntityManager();
93 if (entityMgr!=null) {
94 entityMgr.close();
95 em.set(null);
96 }
97 }
98
99 public static final EntityManager getEntityManager() {
100 return em.get();
101 }
102
103 @SuppressWarnings("unused")
104 private void dump(Context context, String name) {
105 StringBuilder text = new StringBuilder();
106 try {
107 doDump(0, text, context, name);
108 }
109 catch (NamingException ex) {}
110 logger.debug(text.toString());
111 }
112
113 private void doDump(int level, StringBuilder text, Context context, String name)
114 throws NamingException {
115 for (NamingEnumeration<NameClassPair> ne = context.list(name); ne.hasMore();) {
116 NameClassPair ncp = (NameClassPair) ne.next();
117 String objectName = ncp.getName();
118 String className = ncp.getClassName();
119 String classText = " :" + className;
120 if (isContext(className)) {
121 text.append(getPad(level) + "+" + objectName + classText +"\n");
122 doDump(level + 1, text, context, name + "/" + objectName);
123 } else {
124 text.append(getPad(level) + "-" + objectName + classText + "\n");
125 }
126 }
127 }
128
129 protected boolean isContext(String className) {
130 try {
131 Class<?> objectClass = Thread.currentThread().getContextClassLoader()
132 .loadClass(className);
133 return Context.class.isAssignableFrom(objectClass);
134 }
135 catch (ClassNotFoundException ex) {
136
137 return false;
138 }
139 }
140
141 protected String getPad(int level) {
142 StringBuffer pad = new StringBuffer();
143 for (int i = 0; i < level; i++) {
144 pad.append(" ");
145 }
146 return pad.toString();
147 }
148
149
150 }