001    /**
002     * Copyright (c) 2000-2013 Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.kernel.test;
016    
017    import com.liferay.portal.kernel.log.Log;
018    import com.liferay.portal.kernel.log.LogFactoryUtil;
019    import com.liferay.portal.kernel.process.ClassPathUtil;
020    import com.liferay.portal.kernel.process.ProcessCallable;
021    import com.liferay.portal.kernel.process.ProcessException;
022    import com.liferay.portal.kernel.process.ProcessExecutor;
023    import com.liferay.portal.kernel.util.MethodHandler;
024    import com.liferay.portal.kernel.util.MethodKey;
025    import com.liferay.portal.kernel.util.PortalClassLoaderUtil;
026    import com.liferay.portal.kernel.util.StringBundler;
027    import com.liferay.portal.kernel.util.StringPool;
028    
029    import java.io.IOException;
030    import java.io.InputStream;
031    import java.io.OutputStream;
032    import java.io.Serializable;
033    
034    import java.lang.reflect.InvocationTargetException;
035    
036    import java.net.InetAddress;
037    import java.net.InetSocketAddress;
038    import java.net.ServerSocket;
039    import java.net.Socket;
040    
041    import java.util.ArrayList;
042    import java.util.List;
043    
044    import org.junit.After;
045    import org.junit.Before;
046    import org.junit.runner.manipulation.Sorter;
047    import org.junit.runners.BlockJUnit4ClassRunner;
048    import org.junit.runners.model.FrameworkMethod;
049    import org.junit.runners.model.InitializationError;
050    import org.junit.runners.model.Statement;
051    import org.junit.runners.model.TestClass;
052    
053    /**
054     * @author Shuyang Zhou
055     */
056    public class NewJVMJUnitTestRunner extends BlockJUnit4ClassRunner {
057    
058            public NewJVMJUnitTestRunner(Class<?> clazz) throws InitializationError {
059                    super(clazz);
060    
061                    _classPath = ClassPathUtil.getJVMClassPath(false);
062    
063                    sort(new Sorter(new DescriptionComparator()));
064            }
065    
066            protected List<String> createArguments(FrameworkMethod frameworkMethod) {
067                    List<String> arguments = new ArrayList<String>();
068    
069                    boolean junitDebug = Boolean.getBoolean("junit.debug");
070    
071                    if (junitDebug) {
072                            arguments.add(_JPDA_OPTIONS);
073                    }
074    
075                    arguments.add("-Djava.net.preferIPv4Stack=true");
076    
077                    String fileName = System.getProperty(
078                            "net.sourceforge.cobertura.datafile");
079    
080                    if (fileName != null) {
081                            arguments.add("-Dnet.sourceforge.cobertura.datafile=" + fileName);
082                    }
083    
084                    return arguments;
085            }
086    
087            protected ServerSocket createServerSocket() {
088                    int port = _START_SERVER_PORT;
089    
090                    while (true) {
091                            try {
092                                    ServerSocket serverSocket = new ServerSocket();
093    
094                                    serverSocket.setReuseAddress(true);
095    
096                                    serverSocket.bind(
097                                            new InetSocketAddress(InetAddress.getLocalHost(), port));
098    
099                                    return serverSocket;
100                            }
101                            catch (IOException ioe) {
102                                    port++;
103                            }
104                    }
105            }
106    
107            @Override
108            protected Statement methodBlock(FrameworkMethod frameworkMethod) {
109                    Thread currentThread = Thread.currentThread();
110    
111                    ClassLoader contextClassLoader = currentThread.getContextClassLoader();
112    
113                    PortalClassLoaderUtil.setClassLoader(contextClassLoader);
114    
115                    TestClass testClass = getTestClass();
116    
117                    List<FrameworkMethod> beforeFrameworkMethods =
118                            testClass.getAnnotatedMethods(Before.class);
119    
120                    List<FrameworkMethod> afterFrameworkMethods =
121                            testClass.getAnnotatedMethods(After.class);
122    
123                    List<String> arguments = createArguments(frameworkMethod);
124    
125                    Class<?> clazz = testClass.getJavaClass();
126    
127                    return new RunInNewJVMStatment(
128                            _classPath, arguments, clazz, beforeFrameworkMethods,
129                            frameworkMethod, afterFrameworkMethods);
130            }
131    
132            protected ProcessCallable<Serializable> processProcessCallable(
133                    ProcessCallable<Serializable> processCallable,
134                    MethodKey testMethodKey) {
135    
136                    return processCallable;
137            }
138    
139            private static final int _HEARTBEAT_MAGIC_MUNBER = 253;
140    
141            private static final String _JPDA_OPTIONS =
142                    "-agentlib:jdwp=transport=dt_socket,address=8001,server=y,suspend=y";
143    
144            private static final int _START_SERVER_PORT = 10234;
145    
146            private static Log _log = LogFactoryUtil.getLog(
147                    NewJVMJUnitTestRunner.class);
148    
149            private String _classPath;
150    
151            private static class HeartbeatClientThread extends Thread {
152    
153                    public HeartbeatClientThread(String name, int serverPort) {
154                            _serverPort = serverPort;
155    
156                            setDaemon(true);
157                            setName(
158                                    HeartbeatClientThread.class.getSimpleName().concat(
159                                            StringPool.POUND).concat(name));
160                    }
161    
162                    @Override
163                    public void run() {
164                            Socket socket = null;
165    
166                            try {
167                                    socket = new Socket(InetAddress.getLocalHost(), _serverPort);
168    
169                                    socket.shutdownInput();
170    
171                                    OutputStream outputStream = null;
172    
173                                    try {
174                                            outputStream = socket.getOutputStream();
175                                    }
176                                    catch (IOException ioe) {
177                                            return;
178                                    }
179    
180                                    try {
181                                            while (!_stop) {
182                                                    outputStream.write(_HEARTBEAT_MAGIC_MUNBER);
183    
184                                                    try {
185                                                            sleep(1000);
186                                                    }
187                                                    catch (InterruptedException ie) {
188                                                    }
189                                            }
190                                    }
191                                    catch (IOException ioe) {
192                                            _log.error(
193                                                    "Main process socket peer closed unexpectedly", ioe);
194    
195                                            System.exit(10);
196                                    }
197                            }
198                            catch (Exception e) {
199                                    _log.error(e, e);
200                            }
201                            finally {
202                                    try {
203                                            socket.close();
204                                    }
205                                    catch (IOException ioe) {
206                                            _log.error(ioe, ioe);
207                                    }
208                            }
209                    }
210    
211                    public void shutdown() {
212                            _stop = true;
213                            interrupt();
214                    }
215    
216                    private static Log _log = LogFactoryUtil.getLog(
217                            HeartbeatClientThread.class);
218    
219                    private final int _serverPort;
220                    private volatile boolean _stop;
221    
222            }
223    
224            private static class HeartbeatServerThread extends Thread {
225    
226                    public HeartbeatServerThread(String name, ServerSocket serverSocket) {
227                            _serverSocket = serverSocket;
228    
229                            setDaemon(true);
230                            setName(
231                                    HeartbeatServerThread.class.getSimpleName().concat(
232                                            StringPool.POUND).concat(name));
233                    }
234    
235                    @Override
236                    public void run() {
237                            try {
238                                    _socket = _serverSocket.accept();
239    
240                                    _serverSocket.close();
241    
242                                    _socket.shutdownOutput();
243    
244                                    InputStream inputStream = null;
245    
246                                    try {
247                                            inputStream = _socket.getInputStream();
248                                    }
249                                    catch (IOException ioe) {
250                                            return;
251                                    }
252    
253                                    int result = -1;
254    
255                                    while ((result = inputStream.read()) != -1) {
256                                            if (result != _HEARTBEAT_MAGIC_MUNBER) {
257                                                    inputStream.close();
258    
259                                                    _socket.close();
260                                                    _socket = null;
261    
262                                                    break;
263                                            }
264                                    }
265                            }
266                            catch (IOException ioe) {
267                                    _log.error(ioe, ioe);
268                            }
269                            finally {
270                                    try {
271                                            _serverSocket.close();
272                                    }
273                                    catch (IOException ioe) {
274                                            _log.error(ioe, ioe);
275                                    }
276    
277                                    if (_socket != null) {
278                                            try {
279                                                    _socket.close();
280                                                    _socket = null;
281                                            }
282                                            catch (IOException ioe) {
283                                                    _log.error(ioe, ioe);
284                                            }
285                                    }
286                            }
287                    }
288    
289                    public void shutdown() {
290                            interrupt();
291    
292                            if (_socket != null) {
293                                    try {
294                                            _socket.close();
295                                    }
296                                    catch (IOException ioe) {
297                                            _log.error(ioe, ioe);
298                                    }
299                            }
300                    }
301    
302                    private final ServerSocket _serverSocket;
303                    private volatile Socket _socket;
304    
305            }
306    
307            private static class TestProcessCallable
308                    implements ProcessCallable<Serializable> {
309    
310                    public TestProcessCallable(
311                            String testClassName, List<MethodKey> beforeMethodKeys,
312                            MethodKey testMethodKey, List<MethodKey> afterMethodKeys,
313                            int serverPort) {
314    
315                            _testClassName = testClassName;
316                            _beforeMethodKeys = beforeMethodKeys;
317                            _testMethodKey = testMethodKey;
318                            _afterMethodKeys = afterMethodKeys;
319                            _serverPort = serverPort;
320                    }
321    
322                    public Serializable call() throws ProcessException {
323                            final HeartbeatClientThread heartbeatClientThread =
324                                    new HeartbeatClientThread(toString(), _serverPort);
325    
326                            Runtime runtime = Runtime.getRuntime();
327    
328                            runtime.addShutdownHook(
329                                    new Thread() {
330    
331                                            @Override
332                                            public void run() {
333                                                    heartbeatClientThread.shutdown();
334                                            }
335    
336                                    }
337                            );
338    
339                            heartbeatClientThread.start();
340    
341                            Thread currentThread = Thread.currentThread();
342    
343                            ClassLoader contextClassLoader =
344                                    currentThread.getContextClassLoader();
345    
346                            try {
347                                    Class<?> clazz = contextClassLoader.loadClass(_testClassName);
348    
349                                    Object object = clazz.newInstance();
350    
351                                    for (MethodKey beforeMethodKey : _beforeMethodKeys) {
352                                            _invoke(beforeMethodKey, object);
353                                    }
354    
355                                    _invoke(_testMethodKey, object);
356    
357                                    for (MethodKey afterMethodKey : _afterMethodKeys) {
358                                            _invoke(afterMethodKey, object);
359                                    }
360                            }
361                            catch (Exception e) {
362                                    throw new ProcessException(e);
363                            }
364    
365                            return StringPool.BLANK;
366                    }
367    
368                    @Override
369                    public String toString() {
370                            StringBundler sb = new StringBundler(4);
371    
372                            sb.append(_testClassName);
373                            sb.append(StringPool.PERIOD);
374                            sb.append(_testMethodKey.getMethodName());
375                            sb.append("()");
376    
377                            return sb.toString();
378                    }
379    
380                    private void _invoke(MethodKey methodKey, Object object)
381                            throws Exception {
382    
383                            MethodHandler methodHandler = new MethodHandler(methodKey);
384    
385                            methodHandler.invoke(object);
386                    }
387    
388                    private static final long serialVersionUID = 1L;
389    
390                    private List<MethodKey> _afterMethodKeys;
391                    private List<MethodKey> _beforeMethodKeys;
392                    private int _serverPort;
393                    private String _testClassName;
394                    private MethodKey _testMethodKey;
395    
396            }
397    
398            private class RunInNewJVMStatment extends Statement {
399    
400                    public RunInNewJVMStatment(
401                            String classPath, List<String> arguments, Class<?> testClass,
402                            List<FrameworkMethod> beforeFrameworkMethods,
403                            FrameworkMethod testFrameworkMethod,
404                            List<FrameworkMethod> afterFrameworkMethods) {
405    
406                            _classPath = classPath;
407                            _arguments = arguments;
408                            _testClassName = testClass.getName();
409    
410                            _beforeMethodKeys = new ArrayList<MethodKey>(
411                                    beforeFrameworkMethods.size());
412    
413                            for (FrameworkMethod frameworkMethod : beforeFrameworkMethods) {
414                                    _beforeMethodKeys.add(
415                                            new MethodKey(frameworkMethod.getMethod()));
416                            }
417    
418                            _testMethodKey = new MethodKey(testFrameworkMethod.getMethod());
419    
420                            _afterMethodKeys = new ArrayList<MethodKey>(
421                                    afterFrameworkMethods.size());
422    
423                            for (FrameworkMethod frameworkMethod : afterFrameworkMethods) {
424                                    _afterMethodKeys.add(
425                                            new MethodKey(frameworkMethod.getMethod()));
426                            }
427                    }
428    
429                    @Override
430                    public void evaluate() throws Throwable {
431                            ServerSocket serverSocket = createServerSocket();
432    
433                            ProcessCallable<Serializable> processCallable =
434                                    new TestProcessCallable(
435                                            _testClassName, _beforeMethodKeys, _testMethodKey,
436                                            _afterMethodKeys, serverSocket.getLocalPort());
437    
438                            HeartbeatServerThread heartbeatServerThread =
439                                    new HeartbeatServerThread(
440                                            processCallable.toString(), serverSocket);
441    
442                            heartbeatServerThread.start();
443    
444                            processCallable = processProcessCallable(
445                                    processCallable, _testMethodKey);
446    
447                            try {
448                                    ProcessExecutor.execute(
449                                            processCallable, _classPath, _arguments);
450                            }
451                            catch (ProcessException pe) {
452                                    Throwable cause = pe.getCause();
453    
454                                    while ((cause instanceof ProcessException) ||
455                                               (cause instanceof InvocationTargetException)) {
456    
457                                            cause = cause.getCause();
458                                    }
459    
460                                    throw cause;
461                            }
462                            finally {
463                                    heartbeatServerThread.shutdown();
464                            }
465                    }
466    
467                    private List<MethodKey> _afterMethodKeys;
468                    private List<String> _arguments;
469                    private List<MethodKey> _beforeMethodKeys;
470                    private String _classPath;
471                    private String _testClassName;
472                    private MethodKey _testMethodKey;
473    
474            }
475    
476    }