146. LRU缓存机制

题目地址

题目描述

运用你所掌握的数据结构,设计和实现一个 LRU (最近最少使用) 缓存机制。它应该支持以下操作: 获取数据 get 和 写入数据 put 。

获取数据 get(key) - 如果密钥 (key) 存在于缓存中,则获取密钥的值(总是正数),否则返回 -1。
写入数据 put(key, value) - 如果密钥已经存在,则变更其数据值;如果密钥不存在,则插入该组「密钥/数据值」。当缓存容量达到上限时,它应该在写入新数据之前删除最久未使用的数据值,从而为新的数据值留出空间。

进阶:

你是否可以在 O(1) 时间复杂度内完成这两种操作?

示例:

LRUCache cache = new LRUCache( 2 /* 缓存容量 */ );

cache.put(1, 1);
cache.put(2, 2);
cache.get(1);       // 返回  1
cache.put(3, 3);    // 该操作会使得密钥 2 作废
cache.get(2);       // 返回 -1 (未找到)
cache.put(4, 4);    // 该操作会使得密钥 1 作废
cache.get(1);       // 返回 -1 (未找到)
cache.get(3);       // 返回  3
cache.get(4);       // 返回  4

解法

public class LRUCache  extends LinkedHashMap<Integer,Integer>{

    int size;
    public LRUCache(int capacity) {
        super(capacity, 0.75F, true);
        size=capacity;
    }

    public int get(int key) {
        return super.getOrDefault(key, -1);
    }

    public void put(int key, int value) {
        super.put(key, value);
    }

    @Override
    protected boolean removeEldestEntry(Map.Entry<Integer, Integer> eldest) {
        return size() > size;
    }
}

解题思路

LinkedHashMap

LinkedHashMap继承于HashMap,并且实现了hashMap三个回调函数

  • afterNodeInsertion(boolean evict) { }

在put插入元素之后,是否移除多余的元素,这里是主要根据可以

removeEldestEntry(Map.Entry<K,V> eldest)作为条件

这个方法是protected,所以我们可以重写,自定义移除的条件

  • afterNodeRemoval(Node<K,V> p) { }
    在移除元素之后,将元素从链表中移除

  • afterNodeAccess(Node<K,V> p) { }

在访问元素之后,将该元素放到双向链表的末尾,这里分访问包括两种情况,getput时的更新值。

详情的方法描述可以查看 源于 LinkedHashMap源码的题解。
除了以上三个回调函数外,其实还有一个重要的方法

  • linkNodeLast(LinkedHashMap.Entry<K,V> p)

LinkedHaslinkedhMap重写了put方法,在新建对象的时,会调用linkNodeLast将新增的结点关联到当前链表的尾部或者头部(如果不存在链表)

你以为今天的题解就这么结束了?

手写Map+链表实现LRU

既然hashMap留下三个空方法,那我们为什么自己实现呢?有个坏消息,这三个方法属于default,只能在hashMap同一个包下才能被重写。
那我们就自己手写一个map,这里map指的是1.7版本的,数组加链表实现map(很多代码的方法之间从源码复制出来,自己写了一遍才觉得源码是多么优秀)

1.根据题目我们需要实现三个方法


 - get(Object key)
 - put(K key, V value)
 - getOrDefault(Object key, V defaultValue)

这里第三个方法基于get方法实现,可以直接

 public V getOrDefault(Object key, V defaultValue) {
            V value;
            if ((value = get(key)) == null) {
                return defaultValue;
            }
            return value;
}

2.定义完方法后,我们需要定义几个变量

//存放值得数组
Entry<K, V>[] table;
//数组的长度
int modCount;
//数组使用的个数
int size; 
//链表头
Entry<K, V> head;
//链表尾
Entry<K, V> tail;

3.定义map的Entry对象

     class Entry<K, V> {
     //hash值
            int hash;
     // key
            K key;
    //值
            V value;
    //指向下个对象
            Entry<K,V> next;
    //链表的上一个指向和下一个指向
            Entry<K, V> before, after;
    //  构造方法
            public Entry() {
            }


            public Entry(int hash, K key, V value, Entry<K, V> node) {
                this.hash = hash;
                this.key = key;
                this.value = value;
                this.next = node;
            }
        }

在这里为什么我们定义了三个Entry<K,V>对象next、before, after,在LinkedHashMap的方法中

static class Entry<K,V> extends HashMap.Node<K,V> {
        Entry<K,V> before, after;
        Entry(int hash, K key, V value, Node<K,V> next) {
            super(hash, key, value, next);
        }
    }

可以看的LinkedHashMap也是在HashMap.Node<K,V>之前进行了封装了Entry对象,GET这种写法,由于Entry方法属于内部类,无法访问,这里就定义了三个Entry<K,V>对象.

4. 构造方法,初始化参数

这里,我们简单实现满足题目的要求,所以直接舍弃了扩容方法

public MyLinkedHashMap(int initialCapacity) {
			//定义数组的长度
            this.table = new Entry[initialCapacity];
            //初始化长度
            modCount = initialCapacity;
            //初始化使用的长度
            size=0;
}

5.实现put方法

这里便是我们的重量级方法的第一个put,既然要实现HashMap,那就必须要实现hash值
以下得代码,是我从源码复制过来了,当然我也是手写了一遍,之后发现问题比较多,还是从源码复制过来,这里不禁感慨源码写的真好!!!!

final int hash(Object key) {
int h;
return (key == null) ? 0 : (h = key.hashCode()) ^ (h >>> 16);
 }
public V put(K key, V value) {
            Integer hash = hash(key);
            Entry<K, V> p;
            int n, i;
            n = modCount;
            // (n - 1) & hash 位运算得到下标
            if ((p = table[i = (n - 1) & hash]) == null) {
                table[i] = newNode(key,value,hash,null);
                size++;
            } else {
                Entry<K,V> e; K k;
                if (p.hash == hash &&
                        ((k = p.key) == key || (key != null && key.equals(k))))
                    e = p;
                else {
                //遍历存在下标一致,遍历链表,找到值或者添加在链表尾部
                    for (; ;) {
                        if ((e = p.next) == null) {
                            p.next = newNode(key, value, hash,null);
                            size++;
                            break;
                        }
                        if (e.hash == hash &&
                                ((k = e.key) == key || (key != null && key.equals(k))))
                            break;
                        p = e;
                    }
                }
                if (e != null) {
                    V oldValue = e.value;
                    e.value = value;
                    //更新值之后,更新LRU缓存
                    afterNodeAccess(e);
                    return oldValue;
                }
            }
            // 判断是否需要移除多余元素
            afterNodeInsertion();
            return null;
        }

这里我对几挑选几个重点的源码

table[i = (n - 1) & hash]

这里是1.8源码对数组下标的运算,在源码中,数组长度默认为2的幂次,然后与hash与运算,主要作用是使得下标分布均匀

 Entry<K,V> e; K k;
 if (p.hash == hash &&
((k = p.key) == key || (key != null && key.equals(k))))
 e = p;

在源码中会频繁看的在if添加判断中添加赋值的操作,这也是我这次在读源码收到的最大收获之一,没想到代码可以这么写!!!这里是对当前的存在数组的元素进行Hash值,key的匹配。

 if ((e = p.next) == null) {
 p.next = newNode(key, value, hash,null);
  size++;
  break;
 }

这里便体现了map底层是数组加链表

5.实现get方法

先贴上源码

 public V get(Object key) {
            Entry<K,V> e;
            if ((e = getNode(hash(key), key)) == null)
                return null;
                afterNodeAccess(e);
            return e.value;
        }
Entry<K,V> getNode(int hash, Object key) {
            Entry<K,V>[] tab; Entry<K,V> first, e; int n; K k;
            if ((tab = table) != null && (n = tab.length) > 0 &&
                    (first = tab[(n - 1) & hash]) != null) {
                if (first.hash == hash && 
                        ((k = first.key) == key || (key != null && key.equals(k))))
                    return first;
                if ((e = first.next) != null) {
                    do {
                        if (e.hash == hash &&
                                ((k = e.key) == key || (key != null && key.equals(k))))
                            return e;
                    } while ((e = e.next) != null);
                }
            }
            return null;
        }

这块大部分是源码部分,我去掉有关于treeNode的判断,从源码看出,hashMap实现中使用了很多
do whle()的操作,这里也是为什么1.8使用红黑树的原因了,如果链表过长,遍历的时间也会随之增加。
我们还是分析这里的操作

if (first.hash == hash && 
 ((k = first.key) == key || (key != null && key.equals(k))))
 return first;

默认从数组上取值,如果当前数组上的key+hash正好是我们需要的值,直接返回当前数组的值,
如果不是,则遍历当前数组的链表,直到链表尾部。

6.实现四个关于链表更新的方法

以上,我们基本上完成了对hashMap的实现,既然是LinkedHaslinkedhMap那我们就需要实现之前提到的四个关于链表的操作的方法
1.当我们访问某个节点时,更新节点的链表到尾部

  private void afterNodeAccess(Entry<K,V> e) {
            Entry<K,V> last;
            if ((last = tail) != e) {
                Entry<K,V> p =e, b = p.before, a = p.after;
                p.after = null;
                // 
                if (b == null)
                    head = a;
                else
                    b.after = a;
                //
                if (a != null)
                    a.before = b;
                else
                    last = b;
                if (last == null)
                    head = p;
                else {
                    p.before = last;
                    last.after = p;
                }
                tail = p;
            }
        }

afterNodeInsertion和removeNode、removeEldestEntry、afterNodeRemoval属于关联的操作

  • removeNode移除节点
  • afterNodeRemoval移除链表
  • removeEldestEntry 判断是否能移除元素
  private void afterNodeInsertion() {
  // 当前条件满足时移除,头元素--》最久未使用的数据值
            Entry<K,V> first;
            if ( (first = head) != null && removeEldestEntry()) {
                K key = first.key;
                removeNode(hash(key), key,null, false, true);
            }
        }
private boolean removeEldestEntry() {
//判断是否可以移除元素
    return size>modCount;
 }
//移除链表上该节点
      private void afterNodeRemoval(Entry<K,V> e) {
            Entry<K,V> p =e, b = p.before, a = p.after;
            p.before = p.after = null;
            if (b == null)
                head = a;
            else
                b.after = a;
            if (a == null)
                tail = b;
            else
                a.before = b;
        }
 Entry<K,V> removeNode(int hash, Object key, Object value,
                                   boolean matchValue, boolean movable) {
          Entry<K,V>[] tab; Entry<K,V> p; int n, index;
          if ((tab = table) != null && (n = tab.length) > 0 &&
                  (p = tab[index = (n - 1) & hash]) != null) {
              Entry<K,V> node = null, e; K k; V v;
              if (p.hash == hash &&
                      ((k = p.key) == key || (key != null && key.equals(k))))
                  node = p;
              else if ((e = p.next) != null) {
                      do {
                          if (e.hash == hash &&
                                  ((k = e.key) == key ||
                                          (key != null && key.equals(k)))) {
                              node = e;
                              break;
                          }
                          p = e;
                      } while ((e = e.next) != null);
              }
              if (node != null && (!matchValue || (v = node.value) == value ||
                      (value != null && value.equals(v)))) {
                   if (node == p)
                      tab[index] = node.next;
                  else
                      p.next = node.next;
                  --size;
                  afterNodeRemoval(node);
                  return node;
              }
          }
          return null;
      }

关于这里的判断

  if (node != null && (!matchValue || (v = node.value) == value ||
                      (value != null && value.equals(v)))) 

这我一开始写是 if (node != null){},当前在实际执行中,会出现size的值大于了modCount,所以直接复制了源码的部分的判断条件,这是我源码暂时未看懂的地方,需要继续花时间研究

/**
     * Implements Map.remove and related methods.
     *
     * @param hash hash值
     * @param key the key
     * @param value 匹配值
     * @param matchValue 如果为true,则仅在值相等时删除
     * @param movable 如果为false,则在删除时不要移动其他节点
     * @return the node, or null if none
     */

newNode 新建一个对象的同时,将新的节点更新(linkNodeLast)到链表的尾部

 Entry<K,V> newNode( K key, V value,Integer hash,Entry<K, V> next) {
            Entry<K,V> p= new Entry( hash,key, value,next);
            linkNodeLast(p);
            return p;
        }
 private void linkNodeLast(Entry<K,V> p){
            Entry<K,V> last=tail;
            tail=p;
            if(last==null){
                head=p;
            }else{
                p.before=last;
                last.after=p;
            }
        }

从这里代码可以看出LinkedHashMap时,head属于不活跃的结点,tai属于活跃结点,其中源码关于链表的操作写的十分的优雅,对链表的操作理解有很大的帮助。

7代码

public class LRUCache {
    private final MyLinkedHashMap<Integer, Integer> map;

    public LRUCache(int capacity) {
        map = new MyLinkedHashMap(capacity);
    }

    public int get(int key) {
        return map.getOrDefault(key, -1);
    }

    public void put(int key, int value) {
        map.put(key, value);
    }

    private class MyLinkedHashMap<K, V> {

        private final Entry<K, V>[] table;

        int modCount;

        int size;

        class Entry<K, V> {
            int hash;
            K key;
            V value;
            Entry<K,V> next;

            Entry<K, V> before, after;

            public Entry() {
            }


            public Entry(int hash, K key, V value, Entry<K, V> node) {
                this.hash = hash;
                this.key = key;
                this.value = value;
                this.next = node;
            }
        }

        Entry<K, V> head;

        Entry<K, V> tail;


        public MyLinkedHashMap(int initialCapacity) {
            this.table = new Entry[initialCapacity];
            modCount = initialCapacity;
            size=0;
        }

        final int hash(Object key) {
            int h;
            return (key == null) ? 0 : (h = key.hashCode()) ^ (h >>> 16);
        }

        Entry<K,V> newNode( K key, V value,Integer hash,Entry<K, V> next) {
            Entry<K,V> p= new Entry( hash,key, value,next);
            linkNodeLast(p);
            return p;
        }

        private void linkNodeLast(Entry<K,V> p){
            Entry<K,V> last=tail;
            tail=p;
            if(last==null){
                head=p;
            }else{
                p.before=last;
                last.after=p;
            }
        }
        public V get(Object key) {
            Entry<K,V> e;
            if ((e = getNode(hash(key), key)) == null)
                return null;
                afterNodeAccess(e);
            return e.value;
        }

         Entry<K,V> getNode(int hash, Object key) {
            Entry<K,V>[] tab; Entry<K,V> first, e; int n; K k;
            if ((tab = table) != null && (n = tab.length) > 0 &&
                    (first = tab[(n - 1) & hash]) != null) {
                if (first.hash == hash && // always check first node
                        ((k = first.key) == key || (key != null && key.equals(k))))
                    return first;
                if ((e = first.next) != null) {
                    do {
                        if (e.hash == hash &&
                                ((k = e.key) == key || (key != null && key.equals(k))))
                            return e;
                    } while ((e = e.next) != null);
                }
            }
            return null;
        }

        public V put(K key, V value) {
            Integer hash = hash(key);
            Entry<K, V> p;
            int n, i;
            n = modCount;
            if ((p = table[i = (n - 1) & hash]) == null) {
                table[i] = newNode(key,value,hash,null);
                size++;
            } else {
                Entry<K,V> e; K k;
                if (p.hash == hash &&
                        ((k = p.key) == key || (key != null && key.equals(k))))
                    e = p;
                else {
                    for (int binCount = 0; ; ++binCount) {
                        if ((e = p.next) == null) {
                            p.next = newNode(key, value, hash,null);
                            size++;
                            break;
                        }
                        if (e.hash == hash &&
                                ((k = e.key) == key || (key != null && key.equals(k))))
                            break;
                        p = e;
                    }
                }
                if (e != null) { // existing mapping for key
                    V oldValue = e.value;
                    e.value = value;
                    afterNodeAccess(e);
                    return oldValue;
                }
            }
            afterNodeInsertion();
            return null;
        }

        private void afterNodeInsertion() {
            Entry<K,V> first;
            if ( (first = head) != null && removeEldestEntry()) {
                K key = first.key;
                removeNode(hash(key), key,null, false, true);
            }
        }

      Entry<K,V> removeNode(int hash, Object key, Object value,
                                   boolean matchValue, boolean movable) {
          Entry<K,V>[] tab; Entry<K,V> p; int n, index;
          if ((tab = table) != null && (n = tab.length) > 0 &&
                  (p = tab[index = (n - 1) & hash]) != null) {
              Entry<K,V> node = null, e; K k; V v;
              if (p.hash == hash &&
                      ((k = p.key) == key || (key != null && key.equals(k))))
                  node = p;
              else if ((e = p.next) != null) {
                      do {
                          if (e.hash == hash &&
                                  ((k = e.key) == key ||
                                          (key != null && key.equals(k)))) {
                              node = e;
                              break;
                          }
                          p = e;
                      } while ((e = e.next) != null);
              }
              if (node != null && (!matchValue || (v = node.value) == value ||
                      (value != null && value.equals(v)))) {
                   if (node == p)
                      tab[index] = node.next;
                  else
                      p.next = node.next;
                  --size;
                  afterNodeRemoval(node);
                  return node;
              }
          }
          return null;
      }

        //从链表将该元素删除
        private void afterNodeRemoval(Entry<K,V> e) {
            Entry<K,V> p =e, b = p.before, a = p.after;
            p.before = p.after = null;
            if (b == null)
                head = a;
            else
                b.after = a;
            if (a == null)
                tail = b;
            else
                a.before = b;
        }
        private boolean removeEldestEntry() {
            return size>modCount;
        }

        private void afterNodeAccess(Entry<K,V> e) {
            Entry<K,V> last;
            if ((last = tail) != e) {
                Entry<K,V> p =e, b = p.before, a = p.after;
                p.after = null;
                if (b == null)
                    head = a;
                else
                    b.after = a;
                if (a != null)
                    a.before = b;
                else
                    last = b;
                if (last == null)
                    head = p;
                else {
                    p.before = last;
                    last.after = p;
                }
                tail = p;
            }
        }


        public V getOrDefault(Object key, V defaultValue) {
            V value;
            if ((value = get(key)) == null) {
                return defaultValue;
            }
            return value;
        }

    }
    public static void main(String[] args) {
        LRUCache cache = new LRUCache(2);
        cache.put(1, 1);
        cache.put(2, 2);

        int res1 = cache.get(1);
        System.out.println(res1);

        cache.put(3, 3);

        int res2 = cache.get(2);
        System.out.println(res2);

        int res3 = cache.get(3);
        System.out.println(res3);

        cache.put(4, 4);


        int res4 = cache.get(1);
        System.out.println(res4);

        int res5 = cache.get(3);
        System.out.println(res5);

        int res6 = cache.get(4);
        System.out.println(res6);
    }
}

8 总结

这次对源码的解析,一方面了解hashMap的底层实现,数组加链表,以及知道为什么1.8会加入红黑树,更学习到了更多操作,尤其是对链表的操作、在条件运算的时加入赋值的写法。