I am looking for a reasonable definition of a function weighted_sample that does not return just one random index for a given list of weights (which would be something like

    import random

    def weighted_choice(weights, random=random):
        """ Given a list of weights [w_0, w_1, ..., w_n-1],
            return an index i in range(n) with probability proportional to w_i. """
        rnd = random.random() * sum(weights)
        for i, w in enumerate(weights):
            if w < 0:
                raise ValueError("Negative weight encountered.")
            rnd -= w
            if rnd < 0:
                return i
        raise ValueError("Sum of weights is not positive")

to give a categorical distribution with constant weights), but a random sample of k of those indexes, without replacement, just as random.sample behaves compared to random.choice.
Just as weighted_choice can be written as

    lambda weights: random.choice([val for val, cnt in enumerate(weights) for i in range(cnt)])

weighted_sample can be written as

    lambda weights, k: random.sample([val for val, cnt in enumerate(weights) for i in range(cnt)], k)

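But I would like a solution that does not require me to unravel the weights into a (possibly huge) list.

Edit: If there are any nice algorithms that give me back a histogram/list of frequencies (in the same format as the argument weights) instead of a sequence of indexes, those would also be very useful.
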
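From your code:

    weight_sample_indexes = lambda weights, k: random.sample([val for val, cnt in enumerate(weights) for i in range(cnt)], k)

I assume that the weights are positive integers and that by "without replacement" you mean without replacement on the unraveled sequence.
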
Here's a solution based on random.sample and an O(log n) __getitem__:

    import bisect
    import random
    from collections import Counter
    from collections.abc import Sequence  # Sequence is no longer importable from collections

    def weighted_sample(population, weights, k):
        return random.sample(WeightedPopulation(population, weights), k)

    class WeightedPopulation(Sequence):
        """Read-only view of the unraveled population, without building it."""
        def __init__(self, population, weights):
            assert len(population) == len(weights) > 0
            self.population = population
            self.cumweights = []
            cumsum = 0  # compute cumulative weight
            for w in weights:
                cumsum += w
                self.cumweights.append(cumsum)

        def __len__(self):
            # length of the virtual unraveled list = total weight
            return self.cumweights[-1]

        def __getitem__(self, i):
            if not 0 <= i < len(self):
                raise IndexError(i)
            # binary search for the item whose weight range contains i
            return self.population[bisect.bisect(self.cumweights, i)]

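To see why the __getitem__ lookup works, it may help to look at it in isolation. The sketch below maps every position of the virtual unraveled list to an item index, using the same cumulative-weight and bisect idea for the weights [1, 10, 2]. The helper name flat_index_to_item is hypothetical and only used for illustration; it is not part of the solution above.

```python
import bisect
from itertools import accumulate

def flat_index_to_item(cumweights, i):
    """Hypothetical helper: map position i of the virtual unraveled list
    to the index of the item that occupies that position."""
    return bisect.bisect(cumweights, i)

weights = [1, 10, 2]
cumweights = list(accumulate(weights))  # [1, 11, 13]

# Position 0 belongs to item 0, positions 1..10 to item 1, positions 11..12 to item 2.
mapping = [flat_index_to_item(cumweights, i) for i in range(cumweights[-1])]
print(mapping)
```

Each lookup is a binary search over the cumulative weights, so the unraveled list never has to be materialized.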
Example:

    total = Counter()
    for _ in range(1000):
        sample = weighted_sample("abc", [1,10,2], 5)
        total.update(sample)
    print(sample)
    print("Frequences %s" % (dict(Counter(sample)),))

    # Check that values are sane
    print("Total " + ', '.join("%s: %.0f" % (val, count * 1.0 / min(total.values()))
                               for val, count in total.most_common()))

Output:

    ['b', 'b', 'b', 'c', 'c']
    Frequences {'c': 2, 'b': 3}
    Total b: 10, c: 2, a: 1
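
For the histogram variant asked about in the question's edit, one possible sketch relies on the counts argument that random.sample accepts in Python 3.9 and later, and then tallies the drawn indexes into a list in the same format as weights. The function name weighted_sample_counts and its rng parameter are assumptions made for this illustration, and the weights are assumed to be non-negative integers whose sum is at least k.

```python
import random
from collections import Counter

def weighted_sample_counts(weights, k, rng=random):
    """Hypothetical helper: draw k indexes without replacement from the
    virtual unraveled list and return how often each index was drawn."""
    # counts=weights treats index i as if it were repeated weights[i] times
    # (requires Python 3.9+ and integer weights).
    indexes = rng.sample(range(len(weights)), k, counts=weights)
    drawn = Counter(indexes)
    # Same layout as `weights`: one frequency per index, zeros included.
    return [drawn[i] for i in range(len(weights))]

print(weighted_sample_counts([1, 10, 2], 5))
```

The returned frequencies always sum to k, and no entry can exceed the corresponding weight, which matches the "without replacement on the unraveled sequence" reading.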