001    /**
002     * Copyright (c) 2000-2013 Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.util;
016    
017    import com.liferay.portal.kernel.log.Log;
018    import com.liferay.portal.kernel.log.LogFactoryUtil;
019    
020    import java.lang.reflect.Field;
021    import java.lang.reflect.Modifier;
022    
023    import java.util.HashMap;
024    import java.util.HashSet;
025    import java.util.Map;
026    import java.util.Set;
027    
028    /**
029     * @author Shuyang Zhou
030     */
031    public class DefaultThreadLocalBinder implements ThreadLocalBinder {
032    
033            public void afterPropertiesSet() throws Exception {
034                    if (_threadLocalSources == null) {
035                            throw new IllegalArgumentException("Thread local sources is null");
036                    }
037    
038                    init(getClassLoader());
039            }
040    
041            @Override
042            public void bind() {
043                    Map<ThreadLocal<?>, ?> threadLocalValues = _threadLocalValues.get();
044    
045                    for (Map.Entry<ThreadLocal<?>, ?> entry :
046                                    threadLocalValues.entrySet()) {
047    
048                            ThreadLocal<Object> threadLocal =
049                                    (ThreadLocal<Object>)entry.getKey();
050                            Object value = entry.getValue();
051    
052                            threadLocal.set(value);
053                    }
054            }
055    
056            @Override
057            public void cleanUp() {
058                    for (ThreadLocal<?> threadLocal : _threadLocals) {
059                            threadLocal.remove();
060                    }
061            }
062    
063            public ClassLoader getClassLoader() {
064                    if (_classLoader == null) {
065                            Thread currentThread = Thread.currentThread();
066    
067                            _classLoader = currentThread.getContextClassLoader();
068                    }
069    
070                    return _classLoader;
071            }
072    
073            public void init(ClassLoader classLoader) throws Exception {
074                    for (Map.Entry<String, String> entry : _threadLocalSources.entrySet()) {
075                            String className = entry.getKey();
076                            String fieldName = entry.getValue();
077    
078                            Class<?> clazz = classLoader.loadClass(className);
079    
080                            Field field = ReflectionUtil.getDeclaredField(clazz, fieldName);
081    
082                            if (!ThreadLocal.class.isAssignableFrom(field.getType())) {
083                                    if (_log.isWarnEnabled()) {
084                                            _log.warn(
085                                                    fieldName +
086                                                            " is not of type ThreadLocal. Skip binding.");
087                                    }
088    
089                                    continue;
090                            }
091    
092                            if (!Modifier.isStatic(field.getModifiers())) {
093                                    if (_log.isWarnEnabled()) {
094                                            _log.warn(
095                                                    fieldName +
096                                                            " is not a static ThreadLocal. Skip binding.");
097                                    }
098    
099                                    continue;
100                            }
101    
102                            ThreadLocal<?> threadLocal = (ThreadLocal<?>)field.get(null);
103    
104                            if (threadLocal == null) {
105                                    if (_log.isWarnEnabled()) {
106                                            _log.warn(fieldName + " is not initialized. Skip binding.");
107                                    }
108    
109                                    continue;
110                            }
111    
112                            _threadLocals.add(threadLocal);
113                    }
114            }
115    
116            @Override
117            public void record() {
118                    Map<ThreadLocal<?>, Object> threadLocalValues =
119                            new HashMap<ThreadLocal<?>, Object>();
120    
121                    for (ThreadLocal<?> threadLocal : _threadLocals) {
122                            Object value = threadLocal.get();
123    
124                            threadLocalValues.put(threadLocal, value);
125                    }
126    
127                    _threadLocalValues.set(threadLocalValues);
128            }
129    
130            public void setClassLoader(ClassLoader classLoader) {
131                    _classLoader = classLoader;
132            }
133    
134            public void setThreadLocalSources(Map<String, String> threadLocalSources) {
135                    _threadLocalSources = threadLocalSources;
136            }
137    
138            private static Log _log = LogFactoryUtil.getLog(
139                    DefaultThreadLocalBinder.class);
140    
141            private static ThreadLocal<Map<ThreadLocal<?>, ?>> _threadLocalValues =
142                    new AutoResetThreadLocal<Map<ThreadLocal<?>, ?>>(
143                            DefaultThreadLocalBinder.class + "._threadLocalValueMap") {
144    
145                            @Override
146                            protected Map<ThreadLocal<?>, ?> copy(
147                                    Map<ThreadLocal<?>, ?> threadLocalValueMap) {
148    
149                                    return threadLocalValueMap;
150                            }
151    
152                    };
153    
154            private ClassLoader _classLoader;
155            private Set<ThreadLocal<?>> _threadLocals = new HashSet<ThreadLocal<?>>();
156            private Map<String, String> _threadLocalSources;
157    
158    }