使用场景

存在一批大容量的数据(或数据流),无法一次性加载进内存或硬盘内,需对这批数据进行随机采样。

原理

假设需要采样的数量为 k:

  1. 首先构建长度为 k 的数组,将序列的前 k 个元素放入数组中;
  2. 对于从第  j 个元素(j > k) ,以  k / j的概率来决定该元素是否被替换到数组中,数组中的 k 个元素被替换的概率是相同的;
  3. 当遍历完所有元素之后,数组中剩下的元素即为采样样本。

为什么这样是随机的

蓄水池抽样算法的核心在于先以某一种概率选取数,并在后续过程以另一种概率换掉之前已经被选中的数。因此实际上每个数被最终选中的概率都是被选中的概率 * 不被替换的概率

代码实现

public class ReservoirSamplingTest {
 
    private int[] pool; // 所有数据
    private final int N = 100000; // 数据规模
    private Random random = new Random();
 
    @Before
    public void setUp() throws Exception {
        // 初始化
        pool = new int[N];
        for (int i = 0; i < N; i++) {
            pool[i] = i;
        }
    }
 
    private int[] sampling(int K) {
        int[] result = new int[K];
        
        for (int i = 0; i < K; i++) { // 前 K 个元素直接放入数组中
            result[i] = pool[i];
        }
 
        for (int i = K; i < N; i++) { // K + 1 个元素开始进行概率采样
            int r = random.nextInt(i + 1);
            if (r < K) {
                result[r] = pool[i];
            }
        }
 
        return result;
    }
 
    @Test
    public void test() throws Exception {
        for (int i : sampling(100)) {
            System.out.println(i);
        }
    }
}

参考资料

蓄水池采样算法