Weighted Random Selection

  • Given an array of items, each with a corresponding weight
  • Implement a function that randomly selects an item from the array, where the probability of selecting any item is proportional to its weight.
    • => the probability of picking the item at index i is weights[i] / sum(weights)
  • Return the index of the selected item
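Here is a quick worked example. Take the weights [3, 1, 2, 4] that main() uses further down. The expected probabilities are then 3/10, 1/10, 2/10 and 4/10. The helper probabilities() below is hypothetical and is not part of the selection algorithm. It is a small sketch that only computes weights[i] / sum(weights) for each index, assuming a positive total weight.

```rust
/// Hypothetical helper: expected probability of each index,
/// i.e. weights[i] / sum(weights). Assumes the total weight is positive.
fn probabilities(weights: &[i32]) -> Vec<f64> {
    let total: i32 = weights.iter().sum();
    weights.iter().map(|&w| w as f64 / total as f64).collect()
}

fn main() {
    let weights = [3, 1, 2, 4];
    // Each probability is weight / 10 here
    println!("{:?}", probabilities(&weights)); // [0.3, 0.1, 0.2, 0.4]
}
```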

The point:

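  • Lay the weights end to end as segments on a line, then store only the endpoints of the segments (n values) (p. 60)
  • The endpoint of segment i is the total length of all previous segments plus the length of segment i, i.e. weights[0] + ... + weights[i]
  • These endpoints are stored in prefix_sums[]; a random target in [1, sum(weights)] then falls into segment i with probability weights[i] / sum(weights)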
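Below is a minimal sketch of the endpoint idea, again with the weights [3, 1, 2, 4]. The function prefix_sums() is a hypothetical name. It accumulates the weights with Iterator::scan, and the loop then prints which integer targets fall into each segment.

```rust
/// Hypothetical helper: prefix_sums[i] = weights[0] + ... + weights[i],
/// i.e. the endpoint of segment i on the line [1, total].
fn prefix_sums(weights: &[i32]) -> Vec<i32> {
    weights
        .iter()
        .scan(0, |acc, &w| {
            *acc += w; // running total = endpoint of the current segment
            Some(*acc)
        })
        .collect()
}

fn main() {
    let weights = [3, 1, 2, 4];
    let ends = prefix_sums(&weights);
    println!("{:?}", ends); // [3, 4, 6, 10]
    // Segment i covers the integer targets (ends[i] - weights[i], ends[i]]
    for (i, (&end, &w)) in ends.iter().zip(weights.iter()).enumerate() {
        println!("index {i}: targets {}..={end}", end - w + 1);
    }
}
```

The segments are 1..=3, 4..=4, 5..=6 and 7..=10. Each one contains exactly as many targets as its weight.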

Checklist

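  • 1 - Sorted search space
    • the indices [0, n-1] of prefix_sums, which is sorted because every weight is positive
  • 2 - Narrow the search space
    • p. 62
    • we're looking for the lower bound: the index of the smallest prefix sum that satisfies prefix_sums[mid] ≥ target
    • if prefix_sums[mid] < target, the answer is to the right (left = mid + 1); otherwise mid may be the answer (right = mid)
  • 3 - Choose an exit condition for the while loop
    • while left < right (the loop stops when left == right, i.e. a single candidate remains)
  • 4 - Return the correct value
    • left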
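The four checklist steps map onto a standalone lower-bound search, sketched below. The name lower_bound() is hypothetical. It mirrors the loop in select(), but computes mid as left + (right - left) / 2. That gives the same value here and cannot overflow. The sketch also cross-checks each result against the standard library's slice::partition_point, which returns the first index where the predicate becomes false.

```rust
/// Hypothetical standalone lower bound: smallest index i with prefix_sums[i] >= target.
/// Assumes prefix_sums is non-empty and sorted, and that 1 <= target <= last element.
fn lower_bound(prefix_sums: &[i32], target: i32) -> usize {
    // 1 - Sorted search space: indices [0, n-1]
    let (mut left, mut right) = (0, prefix_sums.len() - 1);
    // 3 - Exit condition: stop when left == right
    while left < right {
        let mid = left + (right - left) / 2;
        // 2 - Narrow the search space
        if prefix_sums[mid] < target {
            left = mid + 1; // mid is too small, the answer is to its right
        } else {
            right = mid; // mid satisfies the condition and may be the answer
        }
    }
    // 4 - Return left (left == right here)
    left
}

fn main() {
    let prefix_sums = [3, 4, 6, 10];
    for target in 1..=10 {
        let i = lower_bound(&prefix_sums, target);
        // Standard library equivalent of the same lower-bound search
        assert_eq!(i, prefix_sums.partition_point(|&s| s < target));
        println!("target {target} -> index {i}");
    }
}
```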

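Complexity (n = number of weights):

Operation     Time         Space
constructor   O(n)         O(n)
select        O(log(n))    O(1)

  • O(n) time: the constructor iterates once over the n weights
  • O(log(n)) time: select runs a binary search over a search space of size n
  • O(1) space: select only uses a few scalar variables and allocates nothing; the constructor, however, stores the n prefix sums, hence O(n) space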
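To make the O(log(n)) bound concrete, the sketch below adds a step counter to the same lower-bound loop. The helper lower_bound_with_steps() is hypothetical. The sketch runs it on n = 1_000_000 weights of 1, so prefix_sums is [1, 2, ..., n]. Each iteration at least roughly halves the search space, so the loop can run at most ceil(log2(1_000_000)) = 20 times.

```rust
/// Hypothetical helper: the lower-bound loop from select(), instrumented
/// to return (index found, number of loop iterations).
fn lower_bound_with_steps(prefix_sums: &[i32], target: i32) -> (usize, u32) {
    let (mut left, mut right) = (0, prefix_sums.len() - 1);
    let mut steps = 0u32;
    while left < right {
        steps += 1;
        let mid = left + (right - left) / 2;
        if prefix_sums[mid] < target {
            left = mid + 1;
        } else {
            right = mid;
        }
    }
    (left, steps)
}

fn main() {
    // n weights of 1, so prefix_sums = [1, 2, ..., n]
    let n: i32 = 1_000_000;
    let prefix_sums: Vec<i32> = (1..=n).collect();
    let worst = [1, n / 2, n]
        .iter()
        .map(|&t| lower_bound_with_steps(&prefix_sums, t).1)
        .max()
        .unwrap();
    // ceil(log2(1_000_000)) = 20, so the loop runs at most 20 times
    println!("worst case over sampled targets: {worst} iterations");
}
```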

About Rust:

// # Cargo.toml
// [dependencies]
// rand = "0.9.1"

// If in a Jupyter cell
// :dep rand = "0.9.1"

// No dependency declaration for rand is needed in the Rust Playground (rand is already available there)

use rand::Rng;

struct WeightedRandomSelection {
    prefix_sums: Vec<i32>,
}

impl WeightedRandomSelection {
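    /// Builds the segment endpoints: prefix_sums[i] = weights[0] + ... + weights[i].
    /// Assumes positive weights; panics if `weights` is empty.
    fn new(weights: Vec<i32>) -> Self {
        assert!(!weights.is_empty(), "weights must not be empty");
        let mut prefix_sums = Vec::with_capacity(weights.len());
        prefix_sums.push(weights[0]);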
        for i in 1..weights.len() {
            prefix_sums.push(prefix_sums[i - 1] + weights[i]);
        }
        Self { prefix_sums }
    }

    fn select(&self) -> usize {
        let mut rng = rand::rng();
        // Pick a random target between 1 and the total weight (the last endpoint), inclusive.
        // Dereferencing copies the i32 out of the reference (i32 is Copy), so .clone() is not needed.
        let target = rng.random_range(1..=*self.prefix_sums.last().unwrap());
        
        let (mut left, mut right) = (0, self.prefix_sums.len() - 1);
        // Perform lower-bound binary search to find which endpoint corresponds to the target.
        while left < right {
            let mid = (left + right) / 2;
            if self.prefix_sums[mid] < target {
                left = mid + 1;
            } else {
                right = mid;
            }
        }
        left
    }
}

fn main() { // omit main() when this code runs in a Jupyter cell
    let wrs = WeightedRandomSelection::new(vec![3, 1, 2, 4]);
    let mut count = 0;
    for _ in 0..1_000 {
        if wrs.select() == 0 {
            count += 1;
        }
    }
    println!("{count}"); // roughly 300 expected: weight 3 out of a total of 10, i.e. 30% of 1_000 draws
}
Output of one run (varies from run to run): 261
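Because select() draws its own target, its result can only be checked statistically, as main() does above. One possible refactor passes the target in explicitly. This sketch assumes that splitting the random draw from the search is acceptable. WeightedSelector and select_with_target() are hypothetical names. The search uses slice::partition_point, which performs the same lower-bound search as the hand-written loop. If every target from 1 to the total weight is swept exactly once, index i must be hit exactly weights[i] times.

```rust
/// Hypothetical testable variant: same prefix sums and lower-bound search
/// as WeightedRandomSelection, but the target is passed in, not drawn.
struct WeightedSelector {
    prefix_sums: Vec<i32>,
}

impl WeightedSelector {
    fn new(weights: &[i32]) -> Self {
        let prefix_sums = weights
            .iter()
            .scan(0, |acc, &w| {
                *acc += w;
                Some(*acc)
            })
            .collect();
        Self { prefix_sums }
    }

    /// Total weight, i.e. the last endpoint (assumes non-empty weights).
    fn total(&self) -> i32 {
        *self.prefix_sums.last().unwrap()
    }

    /// Deterministic core of select(): target must be in 1..=total().
    fn select_with_target(&self, target: i32) -> usize {
        // First index whose prefix sum is >= target (lower bound)
        self.prefix_sums.partition_point(|&s| s < target)
    }
}

fn main() {
    let weights = [3, 1, 2, 4];
    let sel = WeightedSelector::new(&weights);
    // Sweep every possible target once and count which index each one selects
    let mut hits = vec![0i32; weights.len()];
    for target in 1..=sel.total() {
        hits[sel.select_with_target(target)] += 1;
    }
    println!("{:?}", hits); // [3, 1, 2, 4]
    assert_eq!(hits, weights);
}
```

With this split, select() reduces to drawing target with rng.random_range(1..=total) and calling select_with_target(target).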


Published on: Jun 22 2025 at 09:00 AM | Last updated: Jun 22 2025 at 09:00 AM

Copyright © 1964-2025 - 40tude