# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Synthetic

# %%
from vectorbtpro import *
from scipy.stats import levy_stable

def geometric_levy_price(alpha, beta, drift, vol, shape):
    """Simulate multiplicative price factors of a geometric Levy-stable process.

    Standard Levy-stable increments (``loc=0``, ``scale=1``) are drawn with
    ``scipy.stats.levy_stable.rvs``, accumulated along axis 0 and
    exponentiated together with a per-step deterministic term, so the value
    at the 1-based step ``t`` is::

        exp(vol * L_t + (drift - 0.5 * vol ** 2) * t)

    where ``L_t`` is the cumulative sum of the first ``t`` increments.

    Args:
        alpha: Stability parameter forwarded to ``levy_stable.rvs``.
        beta: Skewness parameter forwarded to ``levy_stable.rvs``.
        drift: Deterministic drift of the log-price per step.
        vol: Scale applied to the cumulative stable noise.
        shape: Shape of the output (int or tuple); axis 0 is the time axis.

    Returns:
        numpy.ndarray: Array of the requested shape holding strictly positive
        factors; multiply by a starting value to obtain prices.
    """
    increments = levy_stable.rvs(alpha, beta, loc=0, scale=1, size=shape)
    cum_noise = np.cumsum(increments, axis=0)
    # The deterministic term must grow with time just like the cumulative
    # noise does; otherwise `drift` merely rescales the whole path by a
    # constant. Steps 1..n are shaped to broadcast over any trailing axes.
    steps = np.arange(1, cum_noise.shape[0] + 1).reshape(
        (-1,) + (1,) * (cum_noise.ndim - 1)
    )
    # NOTE(review): the `- 0.5 * vol ** 2` correction is kept from the
    # original formula (it mirrors the Gaussian/GBM case) — confirm it is the
    # intended compensator for alpha < 2.
    return np.exp(vol * cum_noise + (drift - 0.5 * vol ** 2) * steps)

class LevyData(vbt.SyntheticData):
    """Synthetic data class whose values come from `geometric_levy_price`.

    Every generation parameter left as None is resolved through
    `resolve_custom_setting` under the settings path declared below.
    """

    _settings_path = {"custom": "data.custom.levy"}

    @classmethod
    def generate_key(
        cls,
        key,
        index,
        columns,
        start_value=None,
        alpha=None,
        beta=None,
        drift=None,
        vol=None,
        seed=None,
        **kwargs
    ):
        """Generate one DataFrame of synthetic Levy prices.

        Args:
            key: Key being generated (not used in the computation).
            index: Row index of the result; its length sets the number of steps.
            columns: Column index of the result; one path is drawn per column.
            start_value: Factor every generated path is multiplied by.
            alpha: Stability parameter handed to `geometric_levy_price`.
            beta: Skewness parameter handed to `geometric_levy_price`.
            drift: Drift handed to `geometric_levy_price`.
            vol: Volatility handed to `geometric_levy_price`.
            seed: If not None, seeds NumPy's global random state before drawing.
            **kwargs: Accepted but ignored.

        Returns:
            pandas.DataFrame: Values of shape ``(len(index), len(columns))``
            labelled with `index` and `columns`.
        """
        # NOTE(review): presumably falls back to the class's custom settings
        # when the passed value is None — confirm against vectorbtpro.
        resolve = cls.resolve_custom_setting
        start_value = resolve(start_value, "start_value")
        alpha = resolve(alpha, "alpha")
        beta = resolve(beta, "beta")
        drift = resolve(drift, "drift")
        vol = resolve(vol, "vol")
        seed = resolve(seed, "seed")

        # Reseed the global NumPy RNG only when a seed was actually supplied.
        if seed is not None:
            np.random.seed(seed)

        n_rows, n_cols = len(index), len(columns)
        factors = geometric_levy_price(alpha, beta, drift, vol, (n_rows, n_cols))
        return pd.DataFrame(start_value * factors, index=index, columns=columns)

# Default generation parameters for LevyData; `generate_key` resolves a
# setting of the same name for every argument it receives as None.
LevyData.set_custom_settings(
    # NOTE(review): presumably creates the settings path if it does not exist
    # yet — confirm against the vectorbtpro documentation.
    populate_=True,
    start_value=100.0,
    alpha=1.68,
    beta=0.01,
    drift=0.0,
    vol=0.01,
    seed=None,  # None -> `generate_key` leaves the global NumPy RNG unseeded
)

# %%
# Pull a single "Close" feature with one column per symbol, seeded so the
# generated values are reproducible between runs.
levy_data = LevyData.pull(
    "Close",
    keys_are_features=True,
    columns=pd.Index(["BTC/USD", "ETH/USD", "XRP/USD"], name="symbol"),
    start="2020-01-01 UTC",
    end="2021-01-01 UTC",
    seed=42,
)
# Last expression of the cell: shown as the cell output in a notebook.
levy_data.get()

# %%
# Plot the generated data through the `vbt` accessor and display the figure.
levy_data.get().vbt.plot().show()

# %%