001    /**
002     * Copyright (c) 2000-2013 Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.util;
016    
017    import com.liferay.portal.kernel.log.Log;
018    import com.liferay.portal.kernel.log.LogFactoryUtil;
019    
020    import java.io.IOException;
021    
022    import java.lang.ref.WeakReference;
023    import java.lang.reflect.InvocationTargetException;
024    import java.lang.reflect.Method;
025    
026    import java.net.URL;
027    
028    import java.util.ArrayList;
029    import java.util.Collection;
030    import java.util.Collections;
031    import java.util.Enumeration;
032    import java.util.Iterator;
033    import java.util.List;
034    
035    /**
036     * @author Brian Wing Shun Chan
037     * @author Michael C. Han
038     * @author Shuyang Zhou
039     */
040    public class AggregateClassLoader extends ClassLoader {
041    
042            public static ClassLoader getAggregateClassLoader(
043                    ClassLoader parentClassLoader, ClassLoader[] classLoaders) {
044    
045                    if (ArrayUtil.isEmpty(classLoaders)) {
046                            return null;
047                    }
048    
049                    if (classLoaders.length == 1) {
050                            return classLoaders[0];
051                    }
052    
053                    AggregateClassLoader aggregateClassLoader = new AggregateClassLoader(
054                            parentClassLoader);
055    
056                    for (ClassLoader classLoader : classLoaders) {
057                            aggregateClassLoader.addClassLoader(classLoader);
058                    }
059    
060                    return aggregateClassLoader;
061            }
062    
063            public static ClassLoader getAggregateClassLoader(
064                    ClassLoader[] classLoaders) {
065    
066                    if (ArrayUtil.isEmpty(classLoaders)) {
067                            return null;
068                    }
069    
070                    return getAggregateClassLoader(classLoaders[0], classLoaders);
071            }
072    
073            public AggregateClassLoader(ClassLoader classLoader) {
074                    _parentClassLoaderReference = new WeakReference<ClassLoader>(
075                            classLoader);
076            }
077    
078            public void addClassLoader(ClassLoader classLoader) {
079                    List<ClassLoader> classLoaders = getClassLoaders();
080    
081                    if (classLoaders.contains(classLoader)) {
082                            return;
083                    }
084    
085                    if ((classLoader instanceof AggregateClassLoader) &&
086                            classLoader.getParent().equals(getParent())) {
087    
088                            AggregateClassLoader aggregateClassLoader =
089                                    (AggregateClassLoader)classLoader;
090    
091                            for (ClassLoader curClassLoader :
092                                            aggregateClassLoader.getClassLoaders()) {
093    
094                                    addClassLoader(curClassLoader);
095                            }
096                    }
097                    else {
098                            _classLoaderReferences.add(
099                                    new WeakReference<ClassLoader>(classLoader));
100                    }
101            }
102    
103            public void addClassLoader(ClassLoader... classLoaders) {
104                    for (ClassLoader classLoader : classLoaders) {
105                            addClassLoader(classLoader);
106                    }
107            }
108    
109            public void addClassLoader(Collection<ClassLoader> classLoaders) {
110                    for (ClassLoader classLoader : classLoaders) {
111                            addClassLoader(classLoader);
112                    }
113            }
114    
115            @Override
116            public boolean equals(Object obj) {
117                    if (this == obj) {
118                            return true;
119                    }
120    
121                    if (!(obj instanceof AggregateClassLoader)) {
122                            return false;
123                    }
124    
125                    AggregateClassLoader aggregateClassLoader = (AggregateClassLoader)obj;
126    
127                    if (_classLoaderReferences.equals(
128                                    aggregateClassLoader._classLoaderReferences) &&
129                            (((getParent() == null) &&
130                              (aggregateClassLoader.getParent() == null)) ||
131                             ((getParent() != null) &&
132                              getParent().equals(aggregateClassLoader.getParent())))) {
133    
134                            return true;
135                    }
136    
137                    return false;
138            }
139    
140            public List<ClassLoader> getClassLoaders() {
141                    List<ClassLoader> classLoaders = new ArrayList<ClassLoader>(
142                            _classLoaderReferences.size());
143    
144                    Iterator<WeakReference<ClassLoader>> itr =
145                            _classLoaderReferences.iterator();
146    
147                    while (itr.hasNext()) {
148                            WeakReference<ClassLoader> weakReference = itr.next();
149    
150                            ClassLoader classLoader = weakReference.get();
151    
152                            if (classLoader == null) {
153                                    itr.remove();
154                            }
155                            else {
156                                    classLoaders.add(classLoader);
157                            }
158                    }
159    
160                    return classLoaders;
161            }
162    
163            @Override
164            public URL getResource(String name) {
165                    for (ClassLoader classLoader : getClassLoaders()) {
166                            URL url = _getResource(classLoader, name);
167    
168                            if (url != null) {
169                                    return url;
170                            }
171                    }
172    
173                    ClassLoader parentClassLoader = _parentClassLoaderReference.get();
174    
175                    if (parentClassLoader == null) {
176                            return null;
177                    }
178    
179                    return parentClassLoader.getResource(name);
180            }
181    
182            @Override
183            public Enumeration<URL> getResources(String name) throws IOException {
184                    List<URL> urls = new ArrayList<URL>();
185    
186                    for (ClassLoader classLoader : getClassLoaders()) {
187                            urls.addAll(Collections.list(_getResources(classLoader, name)));
188                    }
189    
190                    ClassLoader parentClassLoader = _parentClassLoaderReference.get();
191    
192                    if (parentClassLoader != null) {
193                            urls.addAll(
194                                    Collections.list(_getResources(parentClassLoader, name)));
195                    }
196    
197                    return Collections.enumeration(urls);
198            }
199    
200            @Override
201            public int hashCode() {
202                    if (_classLoaderReferences != null) {
203                            return _classLoaderReferences.hashCode();
204                    }
205                    else {
206                            return 0;
207                    }
208            }
209    
210            @Override
211            protected Class<?> findClass(String name) throws ClassNotFoundException {
212                    for (ClassLoader classLoader : getClassLoaders()) {
213                            try {
214                                    return _findClass(classLoader, name);
215                            }
216                            catch (ClassNotFoundException cnfe) {
217                            }
218                    }
219    
220                    throw new ClassNotFoundException("Unable to find class " + name);
221            }
222    
223            @Override
224            protected synchronized Class<?> loadClass(String name, boolean resolve)
225                    throws ClassNotFoundException {
226    
227                    Class<?> loadedClass = null;
228    
229                    for (ClassLoader classLoader : getClassLoaders()) {
230                            try {
231                                    loadedClass = _loadClass(classLoader, name, resolve);
232    
233                                    break;
234                            }
235                            catch (ClassNotFoundException cnfe) {
236                            }
237                    }
238    
239                    if (loadedClass == null) {
240                            ClassLoader parentClassLoader = _parentClassLoaderReference.get();
241    
242                            if (parentClassLoader == null) {
243                                    throw new ClassNotFoundException(
244                                            "Parent class loader has been garbage collected");
245                            }
246    
247                            loadedClass = _loadClass(parentClassLoader, name, resolve);
248                    }
249                    else if (resolve) {
250                            resolveClass(loadedClass);
251                    }
252    
253                    return loadedClass;
254            }
255    
256            private static Class<?> _findClass(ClassLoader classLoader, String name)
257                    throws ClassNotFoundException {
258    
259                    try {
260                            return (Class<?>) _findClassMethod.invoke(classLoader, name);
261                    }
262                    catch (InvocationTargetException ite) {
263                            throw new ClassNotFoundException(
264                                    "Unable to find class " + name, ite.getTargetException());
265                    }
266                    catch (Exception e) {
267                            throw new ClassNotFoundException("Unable to find class " + name, e);
268                    }
269            }
270    
271            private static URL _getResource(ClassLoader classLoader, String name) {
272                    try {
273                            return (URL)_getResourceMethod.invoke(classLoader, name);
274                    }
275                    catch (InvocationTargetException ite) {
276                            return null;
277                    }
278                    catch (Exception e) {
279                            return null;
280                    }
281            }
282    
283            private static Enumeration<URL> _getResources(
284                            ClassLoader classLoader, String name)
285                    throws IOException {
286    
287                    try {
288                            return (Enumeration<URL>)_getResourcesMethod.invoke(
289                                    classLoader, name);
290                    }
291                    catch (InvocationTargetException ite) {
292                            Throwable t = ite.getTargetException();
293    
294                            throw new IOException(t.getMessage());
295                    }
296                    catch (Exception e) {
297                            throw new IOException(e.getMessage());
298                    }
299            }
300    
301            private static Class<?> _loadClass(
302                            ClassLoader classLoader, String name, boolean resolve)
303                    throws ClassNotFoundException {
304    
305                    try {
306                            return (Class<?>)_loadClassMethod.invoke(
307                                    classLoader, name, resolve);
308                    }
309                    catch (InvocationTargetException ite) {
310                            throw new ClassNotFoundException(
311                                    "Unable to load class " + name, ite.getTargetException());
312                    }
313                    catch (Exception e) {
314                            throw new ClassNotFoundException("Unable to load class " + name, e);
315                    }
316            }
317    
318            private static Log _log = LogFactoryUtil.getLog(AggregateClassLoader.class);
319    
320            private static Method _findClassMethod;
321            private static Method _getResourceMethod;
322            private static Method _getResourcesMethod;
323            private static Method _loadClassMethod;
324    
325            private List<WeakReference<ClassLoader>> _classLoaderReferences =
326                    new ArrayList<WeakReference<ClassLoader>>();
327            private WeakReference<ClassLoader> _parentClassLoaderReference;
328    
329            static {
330                    try {
331                            _findClassMethod = ReflectionUtil.getDeclaredMethod(
332                                    ClassLoader.class, "findClass", String.class);
333                            _getResourceMethod = ReflectionUtil.getDeclaredMethod(
334                                    ClassLoader.class, "getResource", String.class);
335                            _getResourcesMethod = ReflectionUtil.getDeclaredMethod(
336                                    ClassLoader.class, "getResources", String.class);
337                            _loadClassMethod = ReflectionUtil.getDeclaredMethod(
338                                    ClassLoader.class, "loadClass", String.class, boolean.class);
339                    }
340                    catch (Exception e) {
341                            if (_log.isErrorEnabled()) {
342                                    _log.error("Unable to locate required methods", e);
343                            }
344                    }
345            }
346    
347    }