Spring Boot + Shiro with Redis: reducing frequent Redis reads and session updates

background:

Frequent access to Redis shows up in two situations: the session is read from Redis too often, and the session stored in Redis is updated too often. This article describes a solution for each situation.

The first case: frequently go to redis to read the session

There are two solutions for the first case: using a local cache, or getting the session from the request. Both are introduced below:

Solution 1: Use local cache

In the previous article we used the RedisSessionDAO class, which relies on a class called SessionInMemory. This is the shiro-redis author's solution to a single request reading the session from Redis many times: a local cache held in a ThreadLocal. Any read that happens within one second of the session being cached is served from the local cache instead of Redis. Let's look at this code again:

public class RedisSessionDAO extends AbstractSessionDAO {     private static Logger logger = LoggerFactory.getLogger(RedisSessionDAO.class);     private static final String DEFAULT_SESSION_KEY_PREFIX = "shiro:session:";    private String keyPrefix = DEFAULT_SESSION_KEY_PREFIX;     private static final long DEFAULT_SESSION_IN_MEMORY_TIMEOUT = 1000L;    /**     * doReadSession be called about 10 times when login.     * Save Session in ThreadLocal to resolve this problem. sessionInMemoryTimeout is expiration of Session in ThreadLocal.     * The default value is 1000 milliseconds (1s).     * Most of time, you don't need to change it.     */    private long sessionInMemoryTimeout = DEFAULT_SESSION_IN_MEMORY_TIMEOUT;     /**     * expire time in seconds     */    private static final int DEFAULT_EXPIRE = -2;    private static final int NO_EXPIRE = -1;     /**     * Please make sure expire is longer than sesion.getTimeout()     */    private int expire = DEFAULT_EXPIRE;     private static final int MILLISECONDS_IN_A_SECOND = 1000;     private RedisManager redisManager;    private static ThreadLocal sessionsInThread = new ThreadLocal();     @Override    public void update(Session session) throws UnknownSessionException {        // 如果会话过期/停止 没必要再更新了        try {            if (session instanceof ValidatingSession && !((ValidatingSession) session).isValid()) {                return;            }             if (session instanceof ShiroSession) {                // 如果没有主要字段(除lastAccessTime以外其他字段)发生改变                ShiroSession ss = (ShiroSession) session;                if (!ss.isChanged()) {                    return;                }                // 如果没有返回 证明有调用 setAttribute往redis 放的时候永远设置为false                ss.setChanged(false);            }             this.saveSession(session);        } catch (Exception e) {            logger.warn("update Session is failed", e);        }    }     /**     * save session     * @param session     * @throws UnknownSessionException     */ 
   private void saveSession(Session session) throws UnknownSessionException {        if (session == null || session.getId() == null) {            logger.error("session or session id is null");            throw new UnknownSessionException("session or session id is null");        }        String key = getRedisSessionKey(session.getId());        if (expire == DEFAULT_EXPIRE) {            this.redisManager.set(key, session, (int) (session.getTimeout() / MILLISECONDS_IN_A_SECOND));            return;        }        if (expire != NO_EXPIRE && expire * MILLISECONDS_IN_A_SECOND < session.getTimeout()) {            logger.warn("Redis session expire time: "                    + (expire * MILLISECONDS_IN_A_SECOND)                    + " is less than Session timeout: "                    + session.getTimeout()                    + " . It may cause some problems.");        }        this.redisManager.set(key, session, expire);    }     @Override    public void delete(Session session) {        if (session == null || session.getId() == null) {            logger.error("session or session id is null");            return;        }        try {            redisManager.del(getRedisSessionKey(session.getId()));        } catch (Exception e) {            logger.error("delete session error. 
session id= {}",session.getId());        }    }     @Override    public Collection<Session> getActiveSessions() {        Set<Session> sessions = new HashSet<Session>();        try {            Set<String> keys = redisManager.scan(this.keyPrefix + "*");            if (keys != null && keys.size() > 0) {                for (String key:keys) {                    Session s = (Session) redisManager.get(key);                    sessions.add(s);                }            }        } catch (Exception e) {            logger.error("get active sessions error.");        }        return sessions;    }     public Long getActiveSessionsSize() {        Long size = 0L;        try {            size = redisManager.scanSize(this.keyPrefix + "*");        } catch (Exception e) {            logger.error("get active sessions error.");        }        return size;    }     @Override    protected Serializable doCreate(Session session) {        if (session == null) {            logger.error("session is null");            throw new UnknownSessionException("session is null");        }        Serializable sessionId = this.generateSessionId(session);        this.assignSessionId(session, sessionId);        this.saveSession(session);        return sessionId;    }     @Override    protected Session doReadSession(Serializable sessionId) {        if (sessionId == null) {            logger.warn("session id is null");            return null;        }        Session s = getSessionFromThreadLocal(sessionId);         if (s != null) {            return s;        }         logger.debug("read session from redis");        try {            s = (Session) redisManager.get(getRedisSessionKey(sessionId));            setSessionToThreadLocal(sessionId, s);        } catch (Exception e) {            logger.error("read session error. 
settionId= {}",sessionId);        }        return s;    }        // 将 session 存入到 ThredLocal 中    private void setSessionToThreadLocal(Serializable sessionId, Session s) {        Map<Serializable, SessionInMemory> sessionMap = (Map<Serializable, SessionInMemory>) sessionsInThread.get();        if (sessionMap == null) {            sessionMap = new HashMap<Serializable, SessionInMemory>();            sessionsInThread.set(sessionMap);        }        SessionInMemory sessionInMemory = new SessionInMemory();        sessionInMemory.setCreateTime(new Date());        sessionInMemory.setSession(s);        sessionMap.put(sessionId, sessionInMemory);    }    // 获取 session     private Session getSessionFromThreadLocal(Serializable sessionId) {        Session s = null;         if (sessionsInThread.get() == null) {            return null;        }         Map<Serializable, SessionInMemory> sessionMap = (Map<Serializable, SessionInMemory>) sessionsInThread.get();        SessionInMemory sessionInMemory = sessionMap.get(sessionId);        if (sessionInMemory == null) {            return null;        }        Date now = new Date();        long duration = now.getTime() - sessionInMemory.getCreateTime().getTime();        // 判断请求的时间差,若时间差小于设定的时间,则从本地缓存中获取        if (duration < sessionInMemoryTimeout) {            s = sessionInMemory.getSession();            logger.debug("read session from memory");        } else {            sessionMap.remove(sessionId);        }         return s;    }     private String getRedisSessionKey(Serializable sessionId) {        return this.keyPrefix + sessionId;    }     public RedisManager getRedisManager() {        return redisManager;    }     public void setRedisManager(RedisManager redisManager) {        this.redisManager = redisManager;    }     public String getKeyPrefix() {        return keyPrefix;    }     public void setKeyPrefix(String keyPrefix) {        this.keyPrefix = keyPrefix;    }     public long getSessionInMemoryTimeout() {        return 
sessionInMemoryTimeout;    }     public void setSessionInMemoryTimeout(long sessionInMemoryTimeout) {        this.sessionInMemoryTimeout = sessionInMemoryTimeout;    }     public int getExpire() {        return expire;    }     public void setExpire(int expire) {        this.expire = expire;    }}

Now look at SessionInMemory: it has only two member variables, the cached session and the time it was cached, createTime.

import org.apache.shiro.session.Session;

import java.util.Date;

/**
 * Holder pairing a {@link Session} with a creation timestamp.
 *
 * <p>Used as the value type of a ThreadLocal-backed temporary store so that Shiro does not
 * have to read the same session from Redis several times while one request is handled.
 */
public class SessionInMemory {

    /** Timestamp recorded for this entry. */
    private Date createTime;

    /** The stored session. */
    private Session session;

    /** @return the timestamp recorded for this entry */
    public Date getCreateTime() {
        return createTime;
    }

    /** @param createTime the timestamp to record for this entry */
    public void setCreateTime(Date createTime) {
        this.createTime = createTime;
    }

    /** @return the stored session */
    public Session getSession() {
        return session;
    }

    /** @param session the session to store */
    public void setSession(Session session) {
        this.session = session;
    }
}

Solution 2: Get session from request

A better solution is to override the retrieveSession() method of the DefaultWebSessionManager class. When Shiro runs in a web environment, the sessionKey is a WebSessionKey, and this class has an attribute we are familiar with: servletRequest. We can put the session object directly into the request. Within a single request cycle every lookup then gets the session from the request, and because the request is destroyed when it ends, we do not need to manage the cache's scope or lifetime ourselves. To override retrieveSession() we need a custom SessionManager, as follows:

import org.apache.shiro.session.Session;import org.apache.shiro.session.UnknownSessionException;import org.apache.shiro.session.mgt.SessionKey;import org.apache.shiro.web.session.mgt.DefaultWebSessionManager;import org.apache.shiro.web.session.mgt.WebSessionKey;import org.slf4j.Logger;import org.slf4j.LoggerFactory; import javax.servlet.ServletRequest;import java.io.Serializable; /** * @description: 解决单次请求需要多次访问redis */public class ShiroSessionManager extends DefaultWebSessionManager {     private static Logger logger = LoggerFactory.getLogger(DefaultWebSessionManager.class);    /**     * 获取session     * 优化单次请求需要多次访问redis的问题     * @param sessionKey     * @return     * @throws UnknownSessionException     */    @Override    protected Session retrieveSession(SessionKey sessionKey) throws UnknownSessionException {        Serializable sessionId = getSessionId(sessionKey);         ServletRequest request = null;        if (sessionKey instanceof WebSessionKey) {            request = ((WebSessionKey) sessionKey).getServletRequest();        }         if (request != null && null != sessionId) {            Object sessionObj = request.getAttribute(sessionId.toString());            if (sessionObj != null) {                logger.debug("read session from request");                return (Session) sessionObj;            }        }         Session session = super.retrieveSession(sessionKey);        if (request != null && null != sessionId) {            request.setAttribute(sessionId.toString(), session);        }        return session;    }}

Also remember to configure the custom ShiroSessionManager as the SessionManager in ShiroConfig.

The second case: frequently update the session in redis

When the session data changes, the session in Redis is updated, but in most cases only the lastAccessTime (last access time) field of the session changes. Because session expiry in Redis is handled by the expiration set on the Redis entry, updating Redis only to refresh lastAccessTime has little value and merely increases the load on Redis. To reduce Redis access and network pressure, the session in Redis is not updated when only this field changes. To detect changes to any field other than lastAccessTime, we add a flag: Redis is updated only when the flag has been set, otherwise the update returns immediately.

We need to wrap SimpleSession in a subclass and add a flag, isChanged. The code is as follows:

import org.apache.shiro.session.mgt.SimpleSession; import java.io.Serializable;import java.util.Date;import java.util.Map; /** * 由于SimpleSession lastAccessTime更改后也会调用SessionDao update方法, * 增加标识位,如果只是更新lastAccessTime SessionDao update方法直接返回 */public class ShiroSession extends SimpleSession implements Serializable {    // 除lastAccessTime以外其他字段发生改变时为true    private boolean isChanged = false;     public ShiroSession() {        super();        this.setChanged(true);    }     public ShiroSession(String host) {        super(host);        this.setChanged(true);    }      @Override    public void setId(Serializable id) {        super.setId(id);        this.setChanged(true);    }     @Override    public void setStopTimestamp(Date stopTimestamp) {        super.setStopTimestamp(stopTimestamp);        this.setChanged(true);    }     @Override    public void setExpired(boolean expired) {        super.setExpired(expired);        this.setChanged(true);    }     @Override    public void setTimeout(long timeout) {        super.setTimeout(timeout);        this.setChanged(true);    }     @Override    public void setHost(String host) {        super.setHost(host);        this.setChanged(true);    }     @Override    public void setAttributes(Map<Object, Object> attributes) {        super.setAttributes(attributes);        this.setChanged(true);    }     @Override    public void setAttribute(Object key, Object value) {        super.setAttribute(key, value);        this.setChanged(true);    }     @Override    public Object removeAttribute(Object key) {        this.setChanged(true);        return super.removeAttribute(key);    }     /**     * 停止     */    @Override    public void stop() {        super.stop();        this.setChanged(true);    }     /**     * 设置过期     */    @Override    protected void expire() {        this.stop();        this.setExpired(true);    }     public boolean isChanged() {        return isChanged;    }     public void setChanged(boolean isChanged) {        
this.isChanged = isChanged;    }     @Override    public boolean equals(Object obj) {        return super.equals(obj);    }     @Override    protected boolean onEquals(SimpleSession ss) {        return super.onEquals(ss);    }     @Override    public int hashCode() {        return super.hashCode();    }     @Override    public String toString() {        return super.toString();    }}

Write a ShiroSessionFactory class that implements the SessionFactory interface and its createSession() method. The code is as follows:

import org.apache.shiro.session.Session;import org.apache.shiro.session.mgt.SessionContext;import org.apache.shiro.session.mgt.SessionFactory;import org.apache.shiro.web.session.mgt.DefaultWebSessionContext;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import org.springframework.util.StringUtils; import com.cache.ShiroSession; import javax.servlet.http.HttpServletRequest;  public class ShiroSessionFactory implements SessionFactory {    private static final Logger logger = LoggerFactory.getLogger(ShiroSessionFactory.class);     @Override    public Session createSession(SessionContext initData) {        ShiroSession session = new ShiroSession();        HttpServletRequest request = (HttpServletRequest)initData.get(DefaultWebSessionContext.class.getName() + ".SERVLET_REQUEST");        session.setHost(getIpAddress(request));        return session;    }     public static String getIpAddress(HttpServletRequest request) {        String localIP = "127.0.0.1";        String ip = request.getHeader("x-forwarded-for");        if (StringUtils.isEmpty(ip) || (ip.equalsIgnoreCase(localIP)) || "unknown".equalsIgnoreCase(ip)) {            ip = request.getHeader("Proxy-Client-IP");        }        if (StringUtils.isEmpty(ip) || (ip.equalsIgnoreCase(localIP)) || "unknown".equalsIgnoreCase(ip)) {            ip = request.getHeader("WL-Proxy-Client-IP");        }        if (StringUtils.isEmpty(ip) || (ip.equalsIgnoreCase(localIP)) || "unknown".equalsIgnoreCase(ip)) {            ip = request.getRemoteAddr();        }        return ip;    }}

Register the ShiroSessionFactory class in ShiroConfig and remember to assign it to the SessionManager. The code is as follows:

    @Bean("sessionManager")	public SessionManager sessionManager() {		// ....        ShiroSessionManager sessionManager =  new ShiroSessionManager();		sessionManager.setSessionFactory(sessionFactory());        // ....    }     @Bean	public ShiroSessionFactory sessionFactory(){	    ShiroSessionFactory sessionFactory = new ShiroSessionFactory();	    return sessionFactory;	}

Finally, add a check to the update method of RedisSessionDAO: if only the lastAccessTime field of the session has changed, return immediately. The code is as follows:

    @Override    public void update(Session session) throws UnknownSessionException {        // 如果会话过期/停止 没必要再更新了        try {            if (session instanceof ValidatingSession && !((ValidatingSession) session).isValid()) {                return;            }             if (session instanceof ShiroSession) {                // 如果没有主要字段(除lastAccessTime以外其他字段)发生改变                ShiroSession ss = (ShiroSession) session;                if (!ss.isChanged()) {                    return;                }                // 如果没有返回 证明有调用 setAttribute往redis 放的时候永远设置为false                ss.setChanged(false);            }             this.saveSession(session);        } catch (Exception e) {            logger.warn("update Session is failed", e);        }    }

Note: when the session is written to Redis, its changed flag must be false. Otherwise a later request that only changes lastAccessTime would not return early, because the session read back from Redis would still carry the flag set to true. Reaching the step that writes the session to Redis means a method such as setAttribute() must have been called, so the flag is reset to false before the session is stored. The next time the session is read from Redis the flag is false; if only lastAccessTime changes it stays false, and Redis is not touched.