Spaces:
Runtime error
Runtime error
File size: 4,364 Bytes
3380ee9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
"""Helper file for Thompson sampling"""
import pickle
import random
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import config as cfg
# Fix the stdlib RNG once, at import time, so the posterior subsampling
# (random.sample) and per-run picks (random.choice) follow a deterministic
# sequence within a process. NumPy's global RNG, used for the simulated
# demand observations, is not seeded by this call.
random.seed(a=42)
class ThompsonSampler:
    """Interactive Thompson-sampling pricing demo rendered with Streamlit.

    Simulates noisy log-linear demand observations at a few fixed prices,
    loads a precomputed posterior (each element is used as a demand curve
    over ``possible_prices``), and on every run draws one posterior curve,
    records the profit-maximising price for it, and redraws both plots.
    """

    def __init__(self):
        # Single Streamlit slot; every redraw replaces its contents in place.
        self.placeholder = st.empty()
        self.latent_elasticity = cfg.LATENT_ELASTICITY
        # Fixed historical price design: 10 obs at 10, 25 at 7.5, 15 at 11.
        self.price_observations = np.concatenate(
            [np.repeat(10, 10), np.repeat(7.5, 25), np.repeat(11, 15)]
        )
        self.update_demand_observations()
        # Price grid on which demand curves are evaluated and optimised.
        self.possible_prices = np.linspace(0, 20, 100)
        # Optimal prices picked so far; one is appended per pick_posterior().
        self.price_samples = []
        self.latent_demand = self.calc_latent_demand()
        self.latent_price = self.calc_optimal_price(self.latent_demand, sample=False)
        self.update_posteriors()

    def update_demand_observations(self):
        """Resample the noisy demand observed at ``self.price_observations``.

        log(demand) ~ Normal(-elasticity * price + cfg.LATENT_SHAPE,
        cfg.LATENT_STDEV), drawn from NumPy's global RNG (only the stdlib
        ``random`` module is seeded in this file).
        """
        self.demand_observations = np.exp(
            np.random.normal(
                loc=-self.latent_elasticity * self.price_observations + cfg.LATENT_SHAPE,
                scale=cfg.LATENT_STDEV,
            )
        )

    def update_elasticity(self):
        """Adopt the elasticity stored in ``st.session_state`` and redraw.

        Clears the recorded price samples, recomputes the true demand curve,
        the simulated observations and the true optimal price, reloads the
        matching precomputed posterior, and re-renders the plots.
        """
        self.latent_elasticity = st.session_state.latent_elasticity
        self.price_samples = []
        self.latent_demand = self.calc_latent_demand()
        self.update_demand_observations()
        self.latent_price = self.calc_optimal_price(self.latent_demand, sample=False)
        self.update_posteriors(samples=75)
        self.create_plots()

    def create_plots(self, highlighted_sample=None):
        """Render the demand and price plots side by side in the placeholder.

        Args:
            highlighted_sample: optional demand curve (aligned with
                ``possible_prices``) drawn in black on the demand plot.
        """
        with self.placeholder.container():
            demand_col, price_col = st.columns(2)
            with demand_col:
                st.markdown("## Demands")
                fig = self.create_posteriors_plot(highlighted_sample)
                st.write(fig)
                # Release the figure once Streamlit has rendered it.
                plt.close(fig)
            with price_col:
                st.markdown("## Prices")
                fig = self.create_price_plot()
                st.write(fig)
                plt.close(fig)

    def create_price_plot(self):
        """Return a figure of the optimal-price distribution under the posterior.

        A horizontal violin shows the profit-maximising price for every
        posterior curve, grey dots mark the prices picked so far, and the red
        vertical line marks the optimal price for the true demand curve.
        """
        # Draw on the figure's own Axes instead of pyplot's implicit "current
        # figure": Streamlit may serve several sessions concurrently and the
        # global pyplot state is not thread-safe.
        fig, ax = plt.subplots()
        ax.set_xlabel("Price")
        # Paint the y tick labels white to hide them; the y axis has no scale.
        ax.tick_params(axis="y", labelcolor="w")
        price_distr = [
            self.calc_optimal_price(post_demand, sample=False)
            for post_demand in self.posterior
        ]
        ax.violinplot(price_distr, vert=False, showextrema=False)
        # One marker-only artist for all picked prices rather than one plot
        # call per price: price_samples grows by one on every run.
        ax.plot(
            self.price_samples,
            np.ones(len(self.price_samples)),
            linestyle="none",
            marker="o",
            markersize=5,
            color="grey",
        )
        ax.axhline(1, color="black")
        ax.axvline(self.latent_price, ymin=0, color="red")
        return fig

    def create_posteriors_plot(self, highlighted_sample=None):
        """Return a figure of observed demand, true demand and posterior draws.

        Args:
            highlighted_sample: optional demand curve drawn in black on top of
                the faint grey posterior subsample.
        """
        fig, ax = plt.subplots()
        ax.set_xlabel("Price")
        ax.set_ylabel("Demand")
        ax.set_xlim(0, 20)
        ax.set_ylim(0, 10)
        ax.scatter(self.price_observations, self.demand_observations)
        ax.plot(self.possible_prices, self.latent_demand, color="red")
        for posterior_sample in self.posterior_samples:
            ax.plot(self.possible_prices, posterior_sample, color="grey", alpha=0.15)
        if highlighted_sample is not None:
            ax.plot(self.possible_prices, highlighted_sample, color="black")
        return fig

    def calc_latent_demand(self):
        """Return the true demand exp(-elasticity * price + LATENT_SHAPE) on the grid."""
        return np.exp(
            -self.latent_elasticity * self.possible_prices + cfg.LATENT_SHAPE
        )

    @staticmethod
    def _cost(demand):
        """Return total cost ``VARIABLE_COST * demand + FIXED_COST`` element-wise.

        NumPy arithmetic already broadcasts, so ``np.vectorize`` (a Python
        loop) is unnecessary; ``np.asarray`` keeps plain-list inputs working.
        """
        return cfg.VARIABLE_COST * np.asarray(demand) + cfg.FIXED_COST

    def calc_optimal_price(self, sampled_demand, sample=False):
        """Return the grid price that maximises profit for a demand curve.

        Profit is ``price * demand - _cost(demand)`` at each point of
        ``possible_prices``; ``np.argmax`` picks the first (lowest) price on ties.

        Args:
            sampled_demand: demand values aligned with ``possible_prices``.
            sample: if True, also append the chosen price to ``price_samples``.

        Returns:
            The profit-maximising element of ``possible_prices``.
        """
        revenue = self.possible_prices * sampled_demand
        profit = revenue - self._cost(sampled_demand)
        optimal_price = self.possible_prices[np.argmax(profit)]
        if sample:
            self.price_samples.append(optimal_price)
        return optimal_price

    def update_posteriors(self, samples=75):
        """Load the precomputed posterior for the current elasticity.

        Reads ``assets/precalc_results/posterior_<elasticity>.pkl`` into
        ``self.posterior`` and stores ``samples`` distinct elements of it
        (all of them, if there are fewer) in ``self.posterior_samples``.
        """
        # NOTE(review): pickle.load can execute arbitrary code; this presumes
        # the files are bundled app assets, never user-supplied — confirm.
        with open(f"assets/precalc_results/posterior_{self.latent_elasticity}.pkl", "rb") as post:
            self.posterior = pickle.load(post)
        # random.sample only accepts sequences (ndarrays are rejected) and
        # raises ValueError when k exceeds the population, so normalise and
        # clamp; for list posteriors the drawn elements are unchanged.
        population = list(self.posterior)
        self.posterior_samples = random.sample(population, min(samples, len(population)))

    def pick_posterior(self):
        """Draw one curve from the posterior subsample, record its price, redraw."""
        posterior_sample = random.choice(self.posterior_samples)
        self.calc_optimal_price(posterior_sample, sample=True)
        self.create_plots(highlighted_sample=posterior_sample)

    def run(self):
        """Perform one sampling step, first syncing a changed elasticity."""
        if st.session_state.latent_elasticity != self.latent_elasticity:
            self.update_elasticity()
        self.pick_posterior()
|