001    /**
002     * Copyright (c) 2000-2013 Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.util;
016    
017    import com.liferay.portal.kernel.log.Log;
018    import com.liferay.portal.kernel.log.LogFactoryUtil;
019    
020    import java.lang.ref.Reference;
021    import java.lang.reflect.Field;
022    import java.lang.reflect.Method;
023    
024    /**
025     * @author Tina Tian
026     */
027    public class ClearThreadLocalUtil {
028    
029            public static void clearThreadLocal() throws Exception {
030                    if (!_initialized) {
031                            return;
032                    }
033    
034                    Thread[] threads = ThreadUtil.getThreads();
035    
036                    Thread currentThread = Thread.currentThread();
037    
038                    ClassLoader contextClassLoader = currentThread.getContextClassLoader();
039    
040                    for (Thread thread : threads) {
041                            _clearThreadLocal(thread, contextClassLoader);
042                    }
043            }
044    
045            private static void _clearThreadLocal(
046                            Thread thread, ClassLoader classLoader)
047                    throws Exception {
048    
049                    if (thread == null) {
050                            return;
051                    }
052    
053                    Object threadLocalMap = _threadLocalsField.get(thread);
054    
055                    Object inheritableThreadLocalMap = _inheritableThreadLocalsField.get(
056                            thread);
057    
058                    _clearThreadLocalMap(threadLocalMap, classLoader);
059                    _clearThreadLocalMap(inheritableThreadLocalMap, classLoader);
060            }
061    
062            private static void _clearThreadLocalMap(
063                            Object threadLocalMap, ClassLoader classLoader)
064                    throws Exception {
065    
066                    if (threadLocalMap == null) {
067                            return;
068                    }
069    
070                    Object[] table = (Object[])_tableField.get(threadLocalMap);
071    
072                    if (table == null) {
073                            return;
074                    }
075    
076                    int staleEntriesCount = 0;
077    
078                    for (Object tableEntry : table) {
079                            if (tableEntry == null) {
080                                    continue;
081                            }
082    
083                            Object key = ((Reference<?>)tableEntry).get();
084                            Object value = _valueField.get(tableEntry);
085    
086                            boolean remove = false;
087    
088                            if (key != null) {
089                                    Class<?> keyClass = key.getClass();
090    
091                                    ClassLoader keyClassLoader = keyClass.getClassLoader();
092    
093                                    if (keyClassLoader == classLoader) {
094                                            remove = true;
095                                    }
096                            }
097    
098                            if (value != null) {
099                                    Class<?> valueClass = value.getClass();
100    
101                                    ClassLoader valueClassLoader = valueClass.getClassLoader();
102    
103                                    if (valueClassLoader == classLoader) {
104                                            remove = true;
105                                    }
106                            }
107    
108                            if (remove) {
109                                    if (key != null) {
110                                            if (_log.isDebugEnabled()) {
111                                                    Class<?> keyClass = key.getClass();
112    
113                                                    _log.debug(
114                                                            "Clear a ThreadLocal with key of type " +
115                                                                    keyClass.getCanonicalName());
116                                            }
117    
118                                            _removeMethod.invoke(threadLocalMap, key);
119                                    }
120                                    else {
121                                            staleEntriesCount++;
122                                    }
123                            }
124                    }
125    
126                    if (staleEntriesCount > 0) {
127                            _expungeStaleEntriesMethod.invoke(threadLocalMap);
128                    }
129            }
130    
131            private static Log _log = LogFactoryUtil.getLog(ClearThreadLocalUtil.class);
132    
133            private static Method _expungeStaleEntriesMethod;
134            private static Field _inheritableThreadLocalsField;
135            private static boolean _initialized;
136            private static Method _removeMethod;
137            private static Field _tableField;
138            private static Field _threadLocalsField;
139            private static Field _valueField;
140    
141            static {
142                    try {
143                            _inheritableThreadLocalsField = ReflectionUtil.getDeclaredField(
144                                    Thread.class, "inheritableThreadLocals");
145                            _threadLocalsField = ReflectionUtil.getDeclaredField(
146                                    Thread.class, "threadLocals");
147    
148                            Class<?> threadLocalMapClass = Class.forName(
149                                    "java.lang.ThreadLocal$ThreadLocalMap");
150    
151                            _expungeStaleEntriesMethod = ReflectionUtil.getDeclaredMethod(
152                                    threadLocalMapClass, "expungeStaleEntries");
153                            _removeMethod = ReflectionUtil.getDeclaredMethod(
154                                    threadLocalMapClass, "remove", ThreadLocal.class);
155                            _tableField = ReflectionUtil.getDeclaredField(
156                                    threadLocalMapClass, "table");
157    
158                            Class<?> threadLocalMapEntryClass = Class.forName(
159                                    "java.lang.ThreadLocal$ThreadLocalMap$Entry");
160    
161                            _valueField = ReflectionUtil.getDeclaredField(
162                                    threadLocalMapEntryClass, "value");
163    
164                            _initialized = true;
165                    }
166                    catch (Throwable t) {
167                            _initialized = false;
168    
169                            if (_log.isWarnEnabled()) {
170                                    _log.warn("Failed to initialize ClearThreadLocalUtil", t);
171                            }
172                    }
173            }
174    
175    }