package win.liyufan.im;
import java.util.HashMap;import java.util.Iterator;import java.util.Map;
/** * 漏桶算法 */public class RateLimiter { private static final int DEFAULT_LIMIT_TIME_SECOND = 5; private static final int DEFAULT_LIMIT_COUNT = 100; private static final long expire = 2 * 60 * 60 * 1000; /** * 允许的请求速率,默认20/s,即,漏桶以该速率速率流出, */ private double rate = (double) DEFAULT_LIMIT_COUNT / (DEFAULT_LIMIT_TIME_SECOND); /** * 最大请求次数, * 1000 是为了单位对齐, 漏桶算法的实现是按照毫秒为单位, */ private long capacity = DEFAULT_LIMIT_COUNT * 1000;
/** * 最后请求时间 */ private long lastCleanTime;
/** * 记录用户的请求次数 */ private Map<String, Long> requestCountMap = new HashMap<>();
/** * 记录用户的请求时间 */ private Map<String, Long> requestTimeMap = new HashMap<>();
/** * cas自旋锁 */ private SpinLock lock = new SpinLock();
public RateLimiter() {
}
/** * 构造一个限流器,指定每秒运行多少个请求 * @param limitTimeSecond * @param limitCount */ public RateLimiter(int limitTimeSecond, int limitCount) { if (limitTimeSecond <= 0 || limitCount <= 0) { throw new IllegalArgumentException(); } // 2000 this.capacity = limitCount * 1000; // 2 this.rate = (double) limitCount / limitTimeSecond; }
/** * 漏桶算法,https://en.wikipedia.org/wiki/Leaky_bucket */ public boolean isGranted(String userId) { try { lock.lock(); long current = System.currentTimeMillis(); cleanUp(current); Long lastRequestTime = requestTimeMap.get(userId); long count = 0; if (lastRequestTime == null) { count += 1000; requestTimeMap.put(userId, current); requestCountMap.put(userId, count); return true; } else { count = requestCountMap.get(userId); lastRequestTime = requestTimeMap.get(userId); // 漏桶流出 count -= (current - lastRequestTime) * rate; count = count > 0 ? count : 0; requestTimeMap.put(userId, current); if (count < capacity) { count += 1000; requestCountMap.put(userId, count); return true; } else { requestCountMap.put(userId, count); return false; } } } finally { lock.unLock(); } }
private void cleanUp(long current) { // 过期时间2个小时,达到过期时间,删除requestCountMap if (current - lastCleanTime > expire) { for (Iterator<Map.Entry<String, Long>> it = requestTimeMap.entrySet().iterator(); it.hasNext();) { Map.Entry<String, Long> entry = it.next(); if (entry.getValue() < current - expire) { it.remove(); requestCountMap.remove(entry.getKey()); } } lastCleanTime = current; } }}
// 自旋锁代码public class SpinLock { //java中原子(CAS)操作 AtomicReference<Thread> owner = new AtomicReference<>();//持有自旋锁的线程对象 private int count; public void lock() { Thread cur = Thread.currentThread(); //lock函数将owner设置为当前线程,并且预测原来的值为空。unlock函数将owner设置为null,并且预测值为当前线程。当有第二个线程调用lock操作时由于owner值不为空,导致循环
//一直被执行,直至第一个线程调用unlock函数将owner设置为null,第二个线程才能进入临界区。 while (!owner.compareAndSet(null, cur)){ } } public void unLock() { Thread cur = Thread.currentThread(); owner.compareAndSet(cur, null); }}
评论