001    /**
002     * Copyright (c) 2000-2013 Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.util;
016    
017    import java.util.HashMap;
018    import java.util.HashSet;
019    import java.util.Map;
020    import java.util.Set;
021    import java.util.concurrent.atomic.AtomicInteger;
022    
/**
 * A {@link ThreadLocal} that keeps its per-thread values in one of two
 * centralized, per-thread tables (long-lived or short-lived) instead of in the
 * JDK's own thread-local storage. Because all values of one category live in a
 * single table, they can be snapshotted ({@link #getLongLivedThreadLocals()},
 * {@link #getShortLivedThreadLocals()}), replayed onto another thread
 * ({@link #setThreadLocals(Map, Map)}), or discarded in a single call.
 *
 * <p>
 * Each thread reaches its own pair of tables through a static
 * {@link ThreadLocal}, so a table is only ever accessed by its owning thread
 * and is not synchronized.
 * </p>
 *
 * @author Shuyang Zhou
 */
public class CentralizedThreadLocal<T> extends ThreadLocal<T> {

	/**
	 * Discards every long-lived value held by the current thread.
	 */
	public static void clearLongLivedThreadLocals() {
		_longLivedThreadLocals.remove();
	}

	/**
	 * Discards every short-lived value held by the current thread.
	 */
	public static void clearShortLivedThreadLocals() {
		_shortLivedThreadLocals.remove();
	}

	/**
	 * Returns a new map holding a snapshot of the current thread's long-lived
	 * values. Each value is passed through its thread local's
	 * {@link #copy(Object)}; entries for which {@code copy} returns
	 * {@code null} are omitted.
	 *
	 * @return a fresh, mutable map independent of the thread's table
	 */
	public static Map<CentralizedThreadLocal<?>, Object>
		getLongLivedThreadLocals() {

		return _toMap(_longLivedThreadLocals.get());
	}

	/**
	 * Returns a new map holding a snapshot of the current thread's short-lived
	 * values. Each value is passed through its thread local's
	 * {@link #copy(Object)}; entries for which {@code copy} returns
	 * {@code null} are omitted.
	 *
	 * @return a fresh, mutable map independent of the thread's table
	 */
	public static Map<CentralizedThreadLocal<?>, Object>
		getShortLivedThreadLocals() {

		return _toMap(_shortLivedThreadLocals.get());
	}

	/**
	 * Stores the given values into the current thread's tables. Existing values
	 * for the same keys are overwritten; all other entries are left in place.
	 *
	 * <p>
	 * Entries are routed by the map they arrive in, not by the key's own
	 * short-lived flag. A key placed in the wrong map is therefore invisible to
	 * that key's {@link #get()}.
	 * </p>
	 *
	 * @param longLivedThreadLocals values to store in the long-lived table (for
	 *        example a snapshot from {@link #getLongLivedThreadLocals()})
	 * @param shortLivedThreadLocals values to store in the short-lived table
	 *        (for example a snapshot from {@link #getShortLivedThreadLocals()})
	 */
	public static void setThreadLocals(
		Map<CentralizedThreadLocal<?>, Object> longLivedThreadLocals,
		Map<CentralizedThreadLocal<?>, Object> shortLivedThreadLocals) {

		ThreadLocalMap threadLocalMap = _longLivedThreadLocals.get();

		for (Map.Entry<CentralizedThreadLocal<?>, Object> entry :
				longLivedThreadLocals.entrySet()) {

			threadLocalMap.putEntry(entry.getKey(), entry.getValue());
		}

		threadLocalMap = _shortLivedThreadLocals.get();

		for (Map.Entry<CentralizedThreadLocal<?>, Object> entry :
				shortLivedThreadLocals.entrySet()) {

			threadLocalMap.putEntry(entry.getKey(), entry.getValue());
		}
	}

	/**
	 * Creates a thread local stored in the short-lived or the long-lived table.
	 *
	 * @param shortLived {@code true} to use the short-lived table,
	 *        {@code false} to use the long-lived one
	 */
	public CentralizedThreadLocal(boolean shortLived) {
		_shortLived = shortLived;

		// Each category has its own counter, so hashes stay evenly spread
		// within the table that actually holds this instance.

		if (shortLived) {
			_hashCode = _shortLivedNextHashCode.getAndAdd(_HASH_INCREMENT);
		}
		else {
			_hashCode = _longLivedNextHashCode.getAndAdd(_HASH_INCREMENT);
		}
	}

	/**
	 * Returns the current thread's value. If none is present yet,
	 * {@link #initialValue()} is called, and its result is stored and
	 * returned.
	 */
	@Override
	@SuppressWarnings("unchecked")
	public T get() {
		ThreadLocalMap threadLocalMap = _getThreadLocalMap();

		Entry entry = threadLocalMap.getEntry(this);

		if (entry != null) {
			return (T)entry._value;
		}

		T value = initialValue();

		threadLocalMap.putEntry(this, value);

		return value;
	}

	/**
	 * Returns the hash assigned at construction, which the centralized tables
	 * use for bucketing. Equality stays identity-based.
	 */
	@Override
	public int hashCode() {
		return _hashCode;
	}

	/**
	 * Removes the current thread's value for this thread local, if any.
	 */
	@Override
	public void remove() {
		ThreadLocalMap threadLocalMap = _getThreadLocalMap();

		threadLocalMap.removeEntry(this);
	}

	/**
	 * Sets the current thread's value for this thread local.
	 */
	@Override
	public void set(T value) {
		ThreadLocalMap threadLocalMap = _getThreadLocalMap();

		threadLocalMap.putEntry(this, value);
	}

	/**
	 * Returns the value to place in a snapshot map for {@code value}, or
	 * {@code null} to leave this thread local out of the snapshot.
	 *
	 * <p>
	 * By default the value is returned as-is only when its exact class is a
	 * boxed primitive or {@link String}; every other value yields
	 * {@code null}. Subclasses may override this to supply copies of other
	 * types.
	 * </p>
	 *
	 * @param  value the current value, possibly {@code null}
	 * @return the snapshot value, or {@code null} to omit it
	 */
	protected T copy(T value) {
		if (value != null) {
			Class<?> clazz = value.getClass();

			if (_immutableTypes.contains(clazz)) {
				return value;
			}
		}

		return null;
	}

	/**
	 * Copies every entry of the table into a new map, filtering through each
	 * key's {@link #copy(Object)}.
	 */
	@SuppressWarnings("unchecked")
	private static Map<CentralizedThreadLocal<?>, Object> _toMap(
		ThreadLocalMap threadLocalMap) {

		Map<CentralizedThreadLocal<?>, Object> map =
			new HashMap<CentralizedThreadLocal<?>, Object>(
				threadLocalMap._table.length);

		for (Entry head : threadLocalMap._table) {

			// Walk the whole chain. Colliding keys share a bucket, so visiting
			// only the head would silently drop all but one of them.

			for (Entry entry = head; entry != null; entry = entry._next) {
				CentralizedThreadLocal<Object> centralizedThreadLocal =
					(CentralizedThreadLocal<Object>)entry._key;

				Object value = centralizedThreadLocal.copy(entry._value);

				if (value != null) {
					map.put(centralizedThreadLocal, value);
				}
			}
		}

		return map;
	}

	/**
	 * Returns the current thread's table for this instance's category.
	 */
	private ThreadLocalMap _getThreadLocalMap() {
		if (_shortLived) {
			return _shortLivedThreadLocals.get();
		}
		else {
			return _longLivedThreadLocals.get();
		}
	}

	// Same increment java.lang.ThreadLocal uses. Because it is odd, any 2^k
	// consecutively created instances fall into distinct buckets of a 2^k-slot
	// table.

	private static final int _HASH_INCREMENT = 0x61c88647;

	// Exact classes whose values copy() passes through unchanged.

	private static final Set<Class<?>> _immutableTypes =
		new HashSet<Class<?>>();

	static {
		_immutableTypes.add(Boolean.class);
		_immutableTypes.add(Byte.class);
		_immutableTypes.add(Character.class);
		_immutableTypes.add(Short.class);
		_immutableTypes.add(Integer.class);
		_immutableTypes.add(Long.class);
		_immutableTypes.add(Float.class);
		_immutableTypes.add(Double.class);
		_immutableTypes.add(String.class);
	}

	private static final AtomicInteger _longLivedNextHashCode =
		new AtomicInteger();
	private static final ThreadLocal<ThreadLocalMap> _longLivedThreadLocals =
		new ThreadLocalMapThreadLocal();
	private static final AtomicInteger _shortLivedNextHashCode =
		new AtomicInteger();
	private static final ThreadLocal<ThreadLocalMap> _shortLivedThreadLocals =
		new ThreadLocalMapThreadLocal();

	private final int _hashCode;
	private final boolean _shortLived;

	/**
	 * A node in a {@link ThreadLocalMap} bucket chain.
	 */
	private static class Entry {

		public Entry(CentralizedThreadLocal<?> key, Object value, Entry next) {
			_key = key;
			_value = value;
			_next = next;
		}

		private CentralizedThreadLocal<?> _key;
		private Entry _next;
		private Object _value;

	}

	/**
	 * A minimal separate-chaining hash table keyed by thread-local identity.
	 * Its capacity is a power of two, and it grows to double its size once the
	 * entry count reaches two thirds of the capacity.
	 */
	private static class ThreadLocalMap {

		/**
		 * Rehashes every entry into a new table of {@code newCapacity} slots.
		 * {@code newCapacity} must be a power of two; the only caller passes
		 * twice the current length. Once the table is at maximum capacity it
		 * stops growing, and the threshold is raised so expansion is never
		 * attempted again.
		 */
		public void expand(int newCapacity) {
			if (_table.length == _MAXIMUM_CAPACITY) {
				_threshold = Integer.MAX_VALUE;

				return;
			}

			Entry[] newTable = new Entry[newCapacity];

			for (int i = 0; i < _table.length; i++) {
				Entry entry = _table[i];

				if (entry == null) {
					continue;
				}

				_table[i] = null;

				do {
					Entry nextEntry = entry._next;

					int index = entry._key._hashCode & (newCapacity - 1);

					entry._next = newTable[index];

					newTable[index] = entry;

					entry = nextEntry;
				}
				while (entry != null);
			}

			_table = newTable;

			_threshold = newCapacity * 2 / 3;
		}

		/**
		 * Returns the entry for {@code key} (matched by identity), or
		 * {@code null} when absent.
		 */
		public Entry getEntry(CentralizedThreadLocal<?> key) {
			int index = key._hashCode & (_table.length - 1);

			for (Entry entry = _table[index]; entry != null;
					entry = entry._next) {

				if (entry._key == key) {
					return entry;
				}
			}

			return null;
		}

		/**
		 * Replaces the value for {@code key}, or prepends a new entry to its
		 * bucket and grows the table once the pre-insert size reaches the
		 * threshold.
		 */
		public void putEntry(CentralizedThreadLocal<?> key, Object value) {
			int index = key._hashCode & (_table.length - 1);

			for (Entry entry = _table[index]; entry != null;
					entry = entry._next) {

				if (entry._key == key) {
					entry._value = value;

					return;
				}
			}

			_table[index] = new Entry(key, value, _table[index]);

			if (_size++ >= _threshold) {
				expand(2 * _table.length);
			}
		}

		/**
		 * Unlinks the entry for {@code key} (matched by identity), if present.
		 */
		public void removeEntry(CentralizedThreadLocal<?> key) {
			int index = key._hashCode & (_table.length - 1);

			Entry previousEntry = null;

			Entry entry = _table[index];

			while (entry != null) {
				Entry nextEntry = entry._next;

				if (entry._key == key) {
					_size--;

					if (previousEntry == null) {
						_table[index] = nextEntry;
					}
					else {
						previousEntry._next = nextEntry;
					}

					return;
				}

				previousEntry = entry;
				entry = nextEntry;
			}
		}

		private static final int _INITIAL_CAPACITY = 16;

		private static final int _MAXIMUM_CAPACITY = 1 << 30;

		private int _size;
		private Entry[] _table = new Entry[_INITIAL_CAPACITY];
		private int _threshold = _INITIAL_CAPACITY * 2 / 3;

	}

	/**
	 * Supplies each thread with its own empty {@link ThreadLocalMap} on first
	 * access.
	 */
	private static class ThreadLocalMapThreadLocal
		extends ThreadLocal<ThreadLocalMap> {

		@Override
		protected ThreadLocalMap initialValue() {
			return new ThreadLocalMap();
		}

	}

}