/**
 * A player of a multi-armed bandit. Subclasses decide which arm to pull
 * ({@link #chooseArm()}) and may record the outcome of each pull
 * ({@link #update(int, double)}); {@link #play(int)} drives that
 * choose/pull/update cycle.
 *
 * @author Jim Glenn
 * @version 0.1 1/18/2009
 */
public abstract class MultiArmedBanditPlayer
{
    /**
     * The machine this player plays.
     */
    protected MultiArmedBandit machine;

    /**
     * Creates a player for the given machine.
     *
     * @param m the multi-armed bandit this player will play
     */
    public MultiArmedBanditPlayer(MultiArmedBandit m)
    {
        machine = m;
    }

    /**
     * Returns the arm that this player chooses to play.
     *
     * @return the index of the arm this player chooses; it is passed
     *         directly to the machine's {@code play} method, so it must be a
     *         valid arm index for that machine
     */
    public abstract int chooseArm();

    /**
     * Updates this player with the outcome of pulling the given arm.
     * Subclasses use this method to keep track of the history of the
     * machine in order to determine which arm to play.
     *
     * @param arm the index of the arm played
     * @param payout the amount earned on that play
     */
    public abstract void update(int arm, double payout);

    /**
     * Plays the given number of games and returns the total payout.
     * Each game chooses an arm with {@link #chooseArm()}, pulls it, and
     * reports the outcome to {@link #update(int, double)}. Invocations are
     * cumulative: whatever history the player recorded during one
     * invocation is still available during the next.
     *
     * @param plays the number of games to play; zero or a negative value
     *        plays nothing
     * @return the sum of the payouts of all games played (0.0 if none)
     */
    public double play(int plays)
    {
        double total = 0.0;
        for (int p = 0; p < plays; p++)
        {
            int arm = chooseArm();
            double payout = machine.play(arm);
            // let the strategy learn from this outcome before its next choice
            update(arm, payout);
            total += payout;
        }
        return total;
    }
}
import java.util.*;
/**
 * Simulates a player playing a multi-armed bandit.
 *
 * @author Jim Glenn
 * @version 0.1 1/18/2009
 */
public class MultiArmedBanditSimulation
{
    /**
     * Runs the simulation: for each run, shuffles the arms into a new
     * multi-armed bandit, lets a random player play it, and prints the
     * run's winnings; finally prints the average payout per play.
     *
     * @param args optional arguments: the number of runs (default 10)
     *        followed by the number of plays per run (default 10000); both
     *        must be positive integers
     */
    public static void main(String[] args)
    {
        // defaults for command line arguments
        final int DEFAULT_RUNS = 10;
        final int DEFAULT_PLAYS_PER_RUN = 10000;
        int numRuns = DEFAULT_RUNS;
        int playsPerRun = DEFAULT_PLAYS_PER_RUN;

        // parse command line arguments
        try
        {
            if (args.length > 0)
            {
                numRuns = Integer.parseInt(args[0]);
                if (args.length > 1)
                {
                    playsPerRun = Integer.parseInt(args[1]);
                }
            }
        }
        catch (NumberFormatException e)
        {
            exitWithUsage();
        }
        // zero runs or zero plays per run would make the overall average 0/0 (NaN)
        if (numRuns <= 0 || playsPerRun <= 0)
        {
            exitWithUsage();
        }

        OneArmedBandit[] arms = createArms();

        // do multiple runs
        double totalWinnings = 0.0;
        for (int run = 0; run < numRuns; run++)
        {
            // randomly order the arms for this run's multi-armed bandit
            shuffle(arms);
            MultiArmedBandit m = new MultiArmedBandit(arms);

            // create the player and let it at the machine
            MultiArmedBanditPlayer player = new RandomPlayer(m);
            double runWinnings = player.play(playsPerRun);
            System.out.println("Run " + run + ": " + runWinnings);
            totalWinnings += runWinnings;
        }

        // multiply in double arithmetic: numRuns * playsPerRun can overflow an int
        double totalPlays = (double) numRuns * playsPerRun;
        System.out.println("Overall average per play: "
                           + totalWinnings / totalPlays);
    }

    /**
     * Prints the usage message to standard error and terminates the JVM
     * with exit status 1.
     */
    private static void exitWithUsage()
    {
        System.err.println("USAGE: java MultiArmedBanditSimulation [runs [plays-per-run]]");
        System.exit(1);
    }

    /**
     * Creates the three arms used by the simulation. Their mean payouts are
     * 0.75, 0.8 and 1.01, so the expected payout of a uniformly random play
     * is (0.75 + 0.8 + 1.01) / 3 = 0.853333...
     *
     * @return a new array holding the three arms
     */
    private static OneArmedBandit[] createArms()
    {
        // a consistent but modest payout: approximately normal, mean 0.75,
        // target standard deviation 0.005
        OneArmedBandit a1 = BinomialBandit.makeNormalApproximation(0.75, 0.005);
        // a better but less consistent payout: the average of a mean-0.5 arm
        // (4 fair coins worth 0.25) and a mean-1.1 arm (one 0.1 coin worth 11),
        // i.e. 0.8 on average
        OneArmedBandit a2 = new CompositeBandit(new BinomialBandit(4, 0.25),
                                                new BinomialBandit(1, 11, 0.1));
        // a rare windfall: 101 with probability 0.01, expected reward 1.01 > 1
        OneArmedBandit a3 = new BinomialBandit(1, 101, 0.01);
        return new OneArmedBandit[] {a1, a2, a3};
    }

    /**
     * Randomly reorders the elements in the given array in place
     * (Fisher-Yates shuffle).
     *
     * @param arr an array of OneArmedBandits
     */
    private static void shuffle(OneArmedBandit[] arr)
    {
        for (int i = 0; i < arr.length - 1; i++)
        {
            // choose an index in [i, arr.length) to swap with index i
            int swapIndex = (int)(Math.random() * (arr.length - i)) + i;

            // swap
            OneArmedBandit temp = arr[i];
            arr[i] = arr[swapIndex];
            arr[swapIndex] = temp;
        }
    }
}
import java.util.*;
/**
* A multi-armed bandit. Such a machine is a collection of one-armed
* bandits. Each arm may have a different payout distribution. The
* arms are referred to by consecutive integer indices that start with 0.
*
* @author Jim Glenn
* @version 0.1 1/18/2009
*/
public class MultiArmedBandit
{
/**
* A list that holds the one-armed bandits that make up this machine.
*/
private List< OneArmedBandit > arms;
/**
* Creates a multi-armed bandit with only one arm.
*
* @param arm a one-armed bandit
*/
public MultiArmedBandit(OneArmedBandit arm)
{
arms = new ArrayList< OneArmedBandit >();
arms.add(arm);
}
/**
* Creates a multi-armed bandit with two arms.
*
* @param a0 the arm that will be at index 0 of the new machine
* @param a1 the arm that will be at index 1 of the new machine
*/
public MultiArmedBandit(OneArmedBandit a0, OneArmedBandit a1)
{
arms = new ArrayList< OneArmedBandit >();
arms.add(a0);
arms.add(a1);
}
/**
* Creates a multi-armed bandit with three arms.
*
* @param a0 the arm that will be at index 0 of the new machine
* @param a1 the arm that will be at index 1 of the new machine
* @param a2 the arm that will be at index 2 of the new machine
*/
public MultiArmedBandit(OneArmedBandit a0,
OneArmedBandit a1,
OneArmedBandit a2)
{
arms = new ArrayList< OneArmedBandit >();
arms.add(a0);
arms.add(a1);
}
/**
* Creates a multi-armed bandit from the one-armed bandits in the given
* list. The indices into the list will be the indices of the arms
* in the new machine.
*
* @param l a list of one-armed bandits
*/
public MultiArmedBandit(List< OneArmedBandit > l)
{
// make a deep copy of the list
arms = new ArrayList< OneArmedBandit >(l);
}
/**
* Creates a multi-armed bandit from the one-armed bandits in the given
* array. The indices into the array will be the indices of the arms
* in the new machine.
*
* @param a an array of one-armed bandits
*/
public MultiArmedBandit(OneArmedBandit[] a)
{
// make a deep copy of the array; note that the list returned by
// as list is not a deep copy of a
arms = new ArrayList< OneArmedBandit >(Arrays.asList(a));
}
/**
* Returns the number of arms on this machine.
*
* @return the number of arms
*/
public int countArms()
{
return arms.size();
}
/**
* Returns the result of playing the given arm on this machine. Each
* invocation of this method represents a different play. Arms are
* given by consecutive integer indices that start at 0.
*
* @param i the index of an arm on this machine.
*/
public double play(int i)
{
return arms.get(i).play();
}
/**
* Returns the maximum payout possible from one play of this machine.
*
* @return the maximum payout
*/
public double getMaximum()
{
double max = Double.NEGATIVE_INFINITY;
for (OneArmedBandit arm : arms)
{
max = Math.max(arm.getMaximum(), max);
}
return max;
}
}
/**
 * A gaming device whose payout is determined strictly by chance: a single
 * arm that can be pulled over and over.
 *
 * @author Jim Glenn
 * @version 0.1 1/16/2009
 */
public interface OneArmedBandit
{
    /**
     * Plays this machine once with a stake of 1.0 units and reports the
     * payout. Every call counts as a separate play, so implementations
     * should draw a fresh random outcome on each call.
     *
     * @return the payout, a nonnegative number
     */
    double play();

    /**
     * Reports the largest payout a single play of this machine can produce.
     *
     * @return the maximum payout
     */
    double getMaximum();
}
/**
 * A one-armed bandit with payout determined by a binomial
 * distribution. Such a machine can be thought of as flipping n coins
 * of the same value v that each have probability p of coming up heads.
 * The payout is the total value of the coins that came up heads, so its
 * mean is n * v * p and its standard deviation is v * sqrt(n * p * (1 - p)).
 *
 * @author Jim Glenn
 * @version 0.1 1/16/2009
 */
public class BinomialBandit implements OneArmedBandit
{
    /**
     * The number of coins this bandit flips.
     */
    private final int numCoins;

    /**
     * The payout for each coin that comes up heads.
     */
    private final double headValue;

    /**
     * The probability of each coin coming up heads.
     */
    private final double pHead;

    /**
     * Creates a binomial bandit that flips the given number of coins.
     * The coins are assumed to be fair and the payout is the given
     * value for each coin that comes up heads.
     *
     * @param n the number of coins, a positive integer
     * @param v the value of each head, a positive number
     * @throws IllegalArgumentException if n is not positive or v is not a
     *         positive number
     */
    public BinomialBandit(int n, double v)
    {
        this(n, v, 0.5);
    }

    /**
     * Creates a binomial bandit that flips the given number of coins that
     * have the given probability of coming up heads. The payout is the
     * given value for each coin that comes up heads.
     *
     * @param n the number of coins, a positive integer
     * @param v the value of each head, a positive number
     * @param p the probability of a head, between 0.0 and 1.0 inclusive
     * @throws IllegalArgumentException if n is not positive, v is not a
     *         positive number, or p is not in [0.0, 1.0] (NaN is rejected)
     */
    public BinomialBandit(int n, double v, double p)
    {
        // validate arguments
        if (n <= 0)
        {
            throw new IllegalArgumentException("Number of coins must be positive: " + n);
        }
        // negated comparisons so that NaN (for which every comparison is false) is rejected too
        if (!(v > 0.0))
        {
            throw new IllegalArgumentException("Value must be positive: " + v);
        }
        if (!(p >= 0.0 && p <= 1.0))
        {
            throw new IllegalArgumentException("Probability must be between 0.0 and 1.0: " + p);
        }

        numCoins = n;
        headValue = v;
        pHead = p;
    }

    /**
     * Determines the payout for one 1.0 unit play of this machine.
     * The payout is randomized for each invocation of this method.
     *
     * @return the payout, a nonnegative number
     */
    public double play()
    {
        // flip coins, keeping track of heads
        int numHeads = 0;
        for (int c = 0; c < numCoins; c++)
        {
            if (Math.random() < pHead)
            {
                numHeads++;
            }
        }

        // compute payout
        return headValue * numHeads;
    }

    /**
     * Returns the number of coins used by this machine.
     *
     * @return the number of coins
     */
    public int countCoins()
    {
        return numCoins;
    }

    /**
     * Returns the probability of a coin used by this machine coming up heads.
     *
     * @return the probability of a head
     */
    public double getProbability()
    {
        return pHead;
    }

    /**
     * Returns the value of each coin flip this machine makes that comes
     * up heads.
     *
     * @return the value of a head
     */
    public double getHeadValue()
    {
        return headValue;
    }

    /**
     * Returns a measure of how closely this binomial distribution
     * approximates the normal distribution with the same mean and
     * standard deviation. The value is the smaller of the expected number
     * of heads and the expected number of tails. Wikipedia suggests
     * (w/o citation) that a normal approximation to a binomial
     * distribution is good if this quantity is >= 5. The assumption here
     * is that the same value is a good measure of the quality of a
     * binomial approximation of a normal distribution.
     *
     * @return a measure of how well this distribution approximates
     *         normal; larger values mean better approximations
     */
    private double computeQualityOfApproximationOfNormal()
    {
        return Math.min(numCoins * pHead, numCoins * (1.0 - pHead));
    }

    /**
     * Creates a binomial bandit with a payout that is approximately
     * normal. Given the mean and standard deviation of the desired
     * normal distribution, this method attempts to minimize the number
     * of coins used while retaining a good approximation. The
     * current implementation considers an approximation good if the
     * expected number of heads and the expected number of tails are
     * both at least some threshold, currently 5.0. However, the
     * approximation is <I>not</I> guaranteed to meet this definition
     * of good. The approximation is guaranteed to have the correct
     * mean; the standard deviation may be off.
     * <p>
     * NOTE: as the number of coins grows, the expected number of heads
     * approaches (mean / dev)^2 from below. When that quantity is below the
     * threshold, no number of coins is good, and the search gives up at
     * 2^30 coins, which makes {@link #play()} very slow.
     *
     * @param mean the desired mean payout, a positive number
     * @param dev the desired standard deviation of the payout, a positive number
     * @return a binomial bandit whose mean payout is {@code mean}
     * @throws IllegalArgumentException if mean or dev is not a positive number
     */
    public static OneArmedBandit makeNormalApproximation(double mean,
                                                         double dev)
    {
        if (!(mean > 0.0))
        {
            throw new IllegalArgumentException("Mean must be positive: " + mean);
        }
        if (!(dev > 0.0))
        {
            throw new IllegalArgumentException("Standard deviation must be positive: " + dev);
        }

        // For n coins, makeBinomial solves mean = n * v * p and
        // dev = v * sqrt(n * p * (1 - p)), giving
        // p = mean^2 / (n * dev^2 + mean^2). Both n * p and n * (1 - p)
        // then increase with n, so "good" is monotone in n and can be
        // found by doubling followed by binary search.
        final double THRESHOLD = 5.0;

        // find the smallest power of 2 that meets the threshold, stopping at
        // 2^30 so that end * 2 never overflows an int
        int end = 2;
        BinomialBandit b = makeBinomial(mean, dev, end);
        while (end <= Integer.MAX_VALUE / 2
               && b.computeQualityOfApproximationOfNormal() < THRESHOLD)
        {
            end *= 2;
            b = makeBinomial(mean, dev, end);
        }

        // binary search between end/2 and end for the first n that is good;
        // end always holds a good n (or the largest n tried if none was good)
        int start = end / 2;
        while (start < end)
        {
            int mid = (start + end) >>> 1;
            b = makeBinomial(mean, dev, mid);
            if (b.computeQualityOfApproximationOfNormal() < THRESHOLD)
            {
                start = mid + 1;
            }
            else
            {
                // mid is good; keep it as a candidate rather than skipping past it
                end = mid;
            }
        }

        return makeBinomial(mean, dev, end);
    }

    /**
     * Creates a binomial bandit with the given number of coins and
     * the given mean payout and standard deviation of payout.
     *
     * @param mean a positive number
     * @param dev a positive number
     * @param n a positive integer
     * @return a binomial bandit with n coins, mean payout {@code mean}, and
     *         standard deviation {@code dev}
     * @throws IllegalArgumentException if the derived coin value or
     *         probability is invalid (e.g. n is not positive, or mean is not
     *         positive)
     */
    public static BinomialBandit makeBinomial(double mean, double dev, int n)
    {
        double p = mean * mean / (n * dev * dev + mean * mean);
        double v = mean / (n * p);
        return new BinomialBandit(n, v, p);
    }

    /**
     * Returns the maximum possible payout from this machine.
     *
     * @return the maximum payout
     */
    public double getMaximum()
    {
        return numCoins * headValue;
    }

    /**
     * Returns a printable representation of this machine.
     *
     * @return a printable representation of this machine
     */
    @Override
    public String toString()
    {
        return "BinomialBandit(n=" + numCoins + ", p=" + pHead + ", v=" + headValue + ")";
    }

    /**
     * Test for makeNormalApproximation: builds an approximation for the
     * mean and standard deviation given on the command line, samples it,
     * and prints the observed mean and standard deviation.
     *
     * @param args the desired mean and standard deviation
     */
    public static void main(String[] args)
    {
        if (args.length < 2)
        {
            System.err.println("USAGE: java BinomialBandit mean dev");
            return;
        }

        double mean = Double.parseDouble(args[0]);
        double dev = Double.parseDouble(args[1]);

        OneArmedBandit b = BinomialBandit.makeNormalApproximation(mean, dev);
        double sumX = 0.0;
        double sumSquare = 0.0;
        final int NUM_PULLS = 100000;
        for (int i = 0; i < NUM_PULLS; i++)
        {
            double p = b.play();
            sumX += p;
            sumSquare += p * p;
        }

        double bMean = sumX / NUM_PULLS;
        // sample variance E[X^2] - (sample mean)^2; it must use the sample
        // mean, not the requested one. Clamp tiny negative rounding error
        // so sqrt does not produce NaN.
        double bVar = Math.max(0.0, sumSquare / NUM_PULLS - bMean * bMean);
        System.out.println("mean = " + bMean);
        System.out.println("std dev = " + Math.sqrt(bVar));
    }
}
import java.util.*;
/**
* A one-armed bandit that is a composite of other one-armed bandits.
* One play of a composite bandit pulls all of the arms of the constituent
* bandits. The total payout is the average of the payouts of the individual
* machines. The difference between a composite bandit and a one-armed
* bandits is that the former plays <I>all</I> arms at once and returns the
* average payout, while the former plays <I>one</I> arm at a time and
* returns that arm's complete payout.
*
* @author Jim Glenn
* @version 0.1 1/18/2009
*/
public class CompositeBandit implements OneArmedBandit
{
/**
* A list that holds the one-armed bandits that make up this machine.
*/
private List< OneArmedBandit > arms;
/**
* Creates a one-armed bandit with only one arm.
*
* @param arm a one-armed bandit
*/
public CompositeBandit(OneArmedBandit arm)
{
arms = new ArrayList< OneArmedBandit >();
arms.add(arm);
}
/**
* Creates a composite bandit with two arms.
*
* @param a0 a one-armed bandit
* @param a1 a one-armed bandit
*/
public CompositeBandit(OneArmedBandit a0, OneArmedBandit a1)
{
arms = new ArrayList< OneArmedBandit >();
arms.add(a0);
arms.add(a1);
}
/**
* Creates a composite bandit with three arms.
*
* @param a0 a one-armed bandit
* @param a1 a one-armed bandit
* @param a2 a one-armed bandit
*/
public CompositeBandit(OneArmedBandit a0,
OneArmedBandit a1,
OneArmedBandit a2)
{
arms = new ArrayList< OneArmedBandit >();
arms.add(a0);
arms.add(a1);
}
/**
* Creates a composite bandit from the one-armed bandits in the given
* list. The indices into the list will be the indices of the arms
* in the new machine.
*
* @param l a list of one-armed bandits
*/
public CompositeBandit(List< OneArmedBandit > l)
{
// make a deep copy of the list
arms = new ArrayList< OneArmedBandit >(l);
}
/**
* Creates a composite bandit from the one-armed bandits in the given
* array. The indices into the array will be the indices of the arms
* in the new machine.
*
* @param a an array of one-armed bandits
*/
public CompositeBandit(OneArmedBandit[] a)
{
// make a deep copy of the array; note that the list returned by
// as list is not a deep copy of a
arms = new ArrayList< OneArmedBandit >(Arrays.asList(a));
}
/**
* Returns the result of playing the given arm on this machine. Each
* invocation of this method represents a different play. Arms are
* given by consecutive integer indices that start at 0.
*
* @param i the index of an arm on this machine.
*/
public double play()
{
double total = 0.0;
for (OneArmedBandit arm : arms)
{
total += arm.play();
}
return total / arms.size();
}
/**
* Returns the maximum possible payout from this machine.
*
* @return the maximum payout
*/
public double getMaximum()
{
double total = 0.0;
for (OneArmedBandit arm : arms)
{
total += arm.getMaximum();
}
return total / arms.size();
}
}
/**
* A multi-armed bandit player that chooses which arm to play randomly.
*
* @author Jim Glenn
* @version 0.1 1/18/2009
*/
public class RandomPlayer extends MultiArmedBanditPlayer
{
public RandomPlayer(MultiArmedBandit m)
{
super(m);
}
/**
* Returns the arm that this player chooses to play. This implementation
* chooses the arm randomly.
*
* @return the index of the arm this player chooses
*/
public int chooseArm()
{
return (int)(machine.countArms() * Math.random());
}
/**
* Updates this player with the outcome of pulling the given arm.
* Subclasses should override this method to keep track of the history
* of the machine in order to determine which arm to play.
* This implementation does nothing.
*
* @param arm the index of the arm played
* @param payout the amount earned on that play
*/
public void update(int arm, double payout)
{
}
}
This code can also be downloaded as the files
MultiArmedBanditPlayer.java,
MultiArmedBanditSimulation.java,
MultiArmedBandit.java,
OneArmedBandit.java,
BinomialBandit.java,
CompositeBandit.java,
and RandomPlayer.java.