001    /**
002     * Copyright (c) 2000-2013 Liferay, Inc. All rights reserved.
003     *
004     * This library is free software; you can redistribute it and/or modify it under
005     * the terms of the GNU Lesser General Public License as published by the Free
006     * Software Foundation; either version 2.1 of the License, or (at your option)
007     * any later version.
008     *
009     * This library is distributed in the hope that it will be useful, but WITHOUT
010     * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
011     * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
012     * details.
013     */
014    
015    package com.liferay.portal.servlet.filters.secure;
016    
017    import com.liferay.portal.kernel.log.Log;
018    import com.liferay.portal.kernel.log.LogFactoryUtil;
019    import com.liferay.portal.kernel.servlet.HttpHeaders;
020    import com.liferay.portal.kernel.servlet.ProtectedServletRequest;
021    import com.liferay.portal.kernel.util.GetterUtil;
022    import com.liferay.portal.kernel.util.Http;
023    import com.liferay.portal.kernel.util.HttpUtil;
024    import com.liferay.portal.kernel.util.StringBundler;
025    import com.liferay.portal.kernel.util.StringPool;
026    import com.liferay.portal.kernel.util.StringUtil;
027    import com.liferay.portal.kernel.util.Validator;
028    import com.liferay.portal.model.User;
029    import com.liferay.portal.security.auth.CompanyThreadLocal;
030    import com.liferay.portal.security.auth.PrincipalThreadLocal;
031    import com.liferay.portal.security.permission.PermissionChecker;
032    import com.liferay.portal.security.permission.PermissionCheckerFactoryUtil;
033    import com.liferay.portal.security.permission.PermissionThreadLocal;
034    import com.liferay.portal.service.UserLocalServiceUtil;
035    import com.liferay.portal.servlet.filters.BasePortalFilter;
036    import com.liferay.portal.util.Portal;
037    import com.liferay.portal.util.PortalInstances;
038    import com.liferay.portal.util.PortalUtil;
039    import com.liferay.portal.util.PropsUtil;
040    import com.liferay.portal.util.WebKeys;
041    
042    import java.util.HashSet;
043    import java.util.Set;
044    
045    import javax.servlet.FilterChain;
046    import javax.servlet.FilterConfig;
047    import javax.servlet.http.HttpServletRequest;
048    import javax.servlet.http.HttpServletResponse;
049    import javax.servlet.http.HttpSession;
050    
051    /**
052     * @author Brian Wing Shun Chan
053     * @author Raymond Aug??
054     * @author Alexander Chow
055     */
056    public class SecureFilter extends BasePortalFilter {
057    
058            @Override
059            public void init(FilterConfig filterConfig) {
060                    super.init(filterConfig);
061    
062                    _basicAuthEnabled = GetterUtil.getBoolean(
063                            filterConfig.getInitParameter("basic_auth"));
064                    _digestAuthEnabled = GetterUtil.getBoolean(
065                            filterConfig.getInitParameter("digest_auth"));
066                    _usePermissionChecker = GetterUtil.getBoolean(
067                            filterConfig.getInitParameter("use_permission_checker"));
068    
069                    String propertyPrefix = filterConfig.getInitParameter(
070                            "portal_property_prefix");
071    
072                    String[] hostsAllowedArray = null;
073    
074                    if (Validator.isNull(propertyPrefix)) {
075                            hostsAllowedArray = StringUtil.split(
076                                    filterConfig.getInitParameter("hosts.allowed"));
077                            _httpsRequired = GetterUtil.getBoolean(
078                                    filterConfig.getInitParameter("https.required"));
079                    }
080                    else {
081                            hostsAllowedArray = PropsUtil.getArray(
082                                    propertyPrefix + "hosts.allowed");
083                            _httpsRequired = GetterUtil.getBoolean(
084                                    PropsUtil.get(propertyPrefix + "https.required"));
085                    }
086    
087                    for (int i = 0; i < hostsAllowedArray.length; i++) {
088                            _hostsAllowed.add(hostsAllowedArray[i]);
089                    }
090            }
091    
092            protected HttpServletRequest basicAuth(
093                            HttpServletRequest request, HttpServletResponse response)
094                    throws Exception {
095    
096                    HttpSession session = request.getSession();
097    
098                    session.setAttribute(WebKeys.BASIC_AUTH_ENABLED, Boolean.TRUE);
099    
100                    long userId = GetterUtil.getLong(
101                            (String)session.getAttribute(_AUTHENTICATED_USER));
102    
103                    if (userId > 0) {
104                            request = new ProtectedServletRequest(
105                                    request, String.valueOf(userId), HttpServletRequest.BASIC_AUTH);
106    
107                            initThreadLocals(request);
108                    }
109                    else {
110                            try {
111                                    userId = PortalUtil.getBasicAuthUserId(request);
112                            }
113                            catch (Exception e) {
114                                    _log.error(e, e);
115                            }
116    
117                            if (userId > 0) {
118                                    request = setCredentials(
119                                            request, session, userId, HttpServletRequest.BASIC_AUTH);
120                            }
121                            else {
122                                    response.setHeader(HttpHeaders.WWW_AUTHENTICATE, _BASIC_REALM);
123                                    response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
124    
125                                    return null;
126                            }
127                    }
128    
129                    return request;
130            }
131    
132            protected HttpServletRequest digestAuth(
133                            HttpServletRequest request, HttpServletResponse response)
134                    throws Exception {
135    
136                    HttpSession session = request.getSession();
137    
138                    long userId = GetterUtil.getLong(
139                            (String)session.getAttribute(_AUTHENTICATED_USER));
140    
141                    if (userId > 0) {
142                            request = new ProtectedServletRequest(
143                                    request, String.valueOf(userId),
144                                    HttpServletRequest.DIGEST_AUTH);
145    
146                            initThreadLocals(request);
147                    }
148                    else {
149                            try {
150                                    userId = PortalUtil.getDigestAuthUserId(request);
151                            }
152                            catch (Exception e) {
153                                    _log.error(e, e);
154                            }
155    
156                            if (userId > 0) {
157                                    request = setCredentials(
158                                            request, session, userId, HttpServletRequest.DIGEST_AUTH);
159                            }
160                            else {
161    
162                                    // Must generate a new nonce for each 401 (RFC2617, 3.2.1)
163    
164                                    long companyId = PortalInstances.getCompanyId(request);
165    
166                                    String remoteAddress = request.getRemoteAddr();
167    
168                                    String nonce = NonceUtil.generate(companyId, remoteAddress);
169    
170                                    StringBundler sb = new StringBundler(4);
171    
172                                    sb.append(_DIGEST_REALM);
173                                    sb.append(", nonce=\"");
174                                    sb.append(nonce);
175                                    sb.append("\"");
176    
177                                    response.setHeader(HttpHeaders.WWW_AUTHENTICATE, sb.toString());
178                                    response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
179    
180                                    return null;
181                            }
182                    }
183    
184                    return request;
185            }
186    
187            protected void initThreadLocals(HttpServletRequest request)
188                    throws Exception {
189    
190                    HttpSession session = request.getSession();
191    
192                    User user = (User)session.getAttribute(WebKeys.USER);
193    
194                    CompanyThreadLocal.setCompanyId(user.getCompanyId());
195    
196                    PrincipalThreadLocal.setName(user.getUserId());
197                    PrincipalThreadLocal.setPassword(PortalUtil.getUserPassword(request));
198    
199                    if (!_usePermissionChecker) {
200                            return;
201                    }
202    
203                    PermissionChecker permissionChecker =
204                            PermissionCheckerFactoryUtil.create(user);
205    
206                    PermissionThreadLocal.setPermissionChecker(permissionChecker);
207            }
208    
209            protected boolean isAccessAllowed(HttpServletRequest request) {
210                    if (_hostsAllowed.isEmpty()) {
211                            return true;
212                    }
213    
214                    String remoteAddr = request.getRemoteAddr();
215    
216                    if (_hostsAllowed.contains(remoteAddr)) {
217                            return true;
218                    }
219    
220                    String computerAddress = PortalUtil.getComputerAddress();
221    
222                    if (computerAddress.equals(remoteAddr) &&
223                            _hostsAllowed.contains(_SERVER_IP)) {
224    
225                            return true;
226                    }
227    
228                    return false;
229            }
230    
231            @Override
232            protected void processFilter(
233                            HttpServletRequest request, HttpServletResponse response,
234                            FilterChain filterChain)
235                    throws Exception {
236    
237                    String remoteAddr = request.getRemoteAddr();
238    
239                    if (isAccessAllowed(request)) {
240                            if (_log.isDebugEnabled()) {
241                                    _log.debug("Access allowed for " + remoteAddr);
242                            }
243                    }
244                    else {
245                            if (_log.isWarnEnabled()) {
246                                    _log.warn("Access denied for " + remoteAddr);
247                            }
248    
249                            response.sendError(
250                                    HttpServletResponse.SC_FORBIDDEN,
251                                    "Access denied for " + remoteAddr);
252    
253                            return;
254                    }
255    
256                    if (_log.isDebugEnabled()) {
257                            if (_httpsRequired) {
258                                    _log.debug("https is required");
259                            }
260                            else {
261                                    _log.debug("https is not required");
262                            }
263                    }
264    
265                    if (_httpsRequired && !request.isSecure()) {
266                            if (_log.isDebugEnabled()) {
267                                    String completeURL = HttpUtil.getCompleteURL(request);
268    
269                                    _log.debug("Securing " + completeURL);
270                            }
271    
272                            StringBundler redirectURL = new StringBundler(5);
273    
274                            redirectURL.append(Http.HTTPS_WITH_SLASH);
275                            redirectURL.append(request.getServerName());
276                            redirectURL.append(request.getServletPath());
277    
278                            String queryString = request.getQueryString();
279    
280                            if (Validator.isNotNull(queryString)) {
281                                    redirectURL.append(StringPool.QUESTION);
282                                    redirectURL.append(request.getQueryString());
283                            }
284    
285                            if (_log.isDebugEnabled()) {
286                                    _log.debug("Redirect to " + redirectURL);
287                            }
288    
289                            response.sendRedirect(redirectURL.toString());
290                    }
291                    else {
292                            if (_log.isDebugEnabled()) {
293                                    String completeURL = HttpUtil.getCompleteURL(request);
294    
295                                    _log.debug("Not securing " + completeURL);
296                            }
297    
298                            User user = PortalUtil.getUser(request);
299    
300                            if ((user != null) && !user.isDefaultUser()) {
301                                    request = setCredentials(
302                                            request, request.getSession(), user.getUserId(), null);
303                            }
304                            else {
305                                    if (_digestAuthEnabled) {
306                                            request = digestAuth(request, response);
307                                    }
308                                    else if (_basicAuthEnabled) {
309                                            request = basicAuth(request, response);
310                                    }
311                            }
312    
313                            if (request != null) {
314                                    processFilter(getClass(), request, response, filterChain);
315                            }
316                    }
317            }
318    
319            protected HttpServletRequest setCredentials(
320                            HttpServletRequest request, HttpSession session, long userId,
321                            String authType)
322                    throws Exception {
323    
324                    User user = UserLocalServiceUtil.getUser(userId);
325    
326                    String userIdString = String.valueOf(userId);
327    
328                    request = new ProtectedServletRequest(request, userIdString, authType);
329    
330                    session.setAttribute(WebKeys.USER, user);
331                    session.setAttribute(_AUTHENTICATED_USER, userIdString);
332    
333                    initThreadLocals(request);
334    
335                    return request;
336            }
337    
338            protected void setUsePermissionChecker(boolean usePermissionChecker) {
339                    _usePermissionChecker = usePermissionChecker;
340            }
341    
342            private static final String _AUTHENTICATED_USER =
343                    SecureFilter.class + "_AUTHENTICATED_USER";
344    
345            private static final String _BASIC_REALM =
346                    "Basic realm=\"" + Portal.PORTAL_REALM + "\"";
347    
348            private static final String _DIGEST_REALM =
349                    "Digest realm=\"" + Portal.PORTAL_REALM + "\"";
350    
351            private static final String _SERVER_IP = "SERVER_IP";
352    
353            private static Log _log = LogFactoryUtil.getLog(SecureFilter.class);
354    
355            private boolean _basicAuthEnabled;
356            private boolean _digestAuthEnabled;
357            private Set<String> _hostsAllowed = new HashSet<String>();
358            private boolean _httpsRequired;
359            private boolean _usePermissionChecker;
360    
361    }