# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Splitter

# %%
# Build a splitter with `from_rolling` over `data.index` (length=360, split=0.5),
# label its two sets "IS" and "OOS", then render its plot.
# NOTE(review): `vbt` and `data` (and `pd`/`np` used in later cells) are not defined
# in the visible part of this file — presumably set up in an earlier notebook; confirm.
splitter = vbt.Splitter.from_rolling(
    data.index,
    length=360,
    split=0.5,
    set_labels=["IS", "OOS"]
)
splitter.plot().show()

# %% [markdown]
# ## Schema

# %%
# Display the splitter's `splits` attribute.
splitter.splits

# %%
# Display the splitter's `index` attribute.
splitter.index

# %%
# Display the index of the splitter's wrapper.
splitter.wrapper.index

# %%
# Display the columns of the splitter's wrapper.
splitter.wrapper.columns

# %%
# Index the splitter with the "OOS" label and display the splits of the result.
oos_splitter = splitter["OOS"]
oos_splitter.splits

# %% [markdown]
# ### Range format

# %%
# Index a 14-period date range with an explicit slice object.
# Note: this rebinds the notebook-level name `index`.
index = pd.date_range("2020", periods=14)
index[slice(1, 7)]

# %%
# The elements at positions 1 and 6, i.e. the first and last positions covered by slice(1, 7).
index[1], index[6]

# %% [markdown]
# #### Relative

# %%
# Create a relative range with offset=10 and length=40 and display it.
rel_range = vbt.RelRange(offset=10, length=40)
rel_range

# %%
# Convert the relative range to a slice, passing the splitter index length as
# `total_len` and 100 as `prev_end`.
rel_range.to_slice(total_len=len(splitter.index), prev_end=100)

# %% [markdown]
# ### Array format

# %%
# Size in MiB (bytes / 1024 / 1024) of an integer position array spanning a
# 1-minute index from 2020 to 2021.
index = pd.date_range("2020", "2021", freq="1min")
range_ = np.arange(len(index))
range_.nbytes / 1024 / 1024

# %%
# Same size computation for an array of the same length filled with `True`.
range_ = np.full(len(index), True)
range_.nbytes / 1024 / 1024

# %%
# Inspect the dtype of the splitter's `splits_arr`.
splitter.splits_arr.dtype

# %%
# Show the identity of a freshly created slice object.
id(slice(0, 180, None))

# %%
# Four integer position arrays, arranged below as two splits with two ranges each.
range_00 = np.arange(0, 5)
range_01 = np.arange(5, 15)
range_10 = np.arange(15, 30)
range_11 = np.arange(30, 50)

# Build a splitter from the explicit arrays; `fix_ranges=False` is passed to `from_splits`.
ind_splitter = vbt.Splitter.from_splits(
    data.index,
    [[range_00, range_01], [range_10, range_11]],
    fix_ranges=False
)
ind_splitter.splits

# %%
# Look up the entry stored for split 0 under the "set_1" column.
ind_splitter.splits.loc[0, "set_1"]

# %%
# Access the `range_` attribute of that entry.
ind_splitter.splits.loc[0, "set_1"].range_

# %% [markdown]
# ## Preparation
# ### Splits

# %%
# Split the full range (`slice(None)`) through the class, passing the index explicitly:
# a relative range with length=0.75 followed by a default relative range.
vbt.Splitter.split_range(
    slice(None),
    (vbt.RelRange(length=0.75), vbt.RelRange()),
    index=data.index
)

# %%
# Same call through the splitter instance; no `index` argument is passed here.
splitter.split_range(
    slice(None),
    (vbt.RelRange(length=0.75), vbt.RelRange())
)

# %%
# Index `data` with a hard-coded slice of positions 0 to 1426 (end exclusive).
data[slice(0, 1426, None)]

# %%
# Pass a single number (0.75) as the split specification instead of a tuple of ranges.
vbt.Splitter.split_range(
    slice(None),
    0.75,
    index=data.index
)

# %%
# Same with a negative number (-0.25).
vbt.Splitter.split_range(
    slice(None),
    -0.25,
    index=data.index
)

# %%
# Integer part of 75% of the index length.
int(0.75 * len(data.index))

# %%
# Index length minus the integer part of 25% of the index length.
len(data.index) - int(0.25 * len(data.index))

# %%
# Default relative range followed by one with length=0.25, with `backwards=True`.
vbt.Splitter.split_range(
    slice(None),
    (vbt.RelRange(), vbt.RelRange(length=0.25)),
    backwards=True,
    index=data.index
)

# %%
# Plain numbers (1.0 and 30) as the two range specifications, with `backwards=True`.
vbt.Splitter.split_range(
    slice(None),
    (1.0, 30),
    backwards=True,
    index=data.index
)

# %%
# Three relative ranges; the first two use length=0.4 with length_space="all".
vbt.Splitter.split_range(
    slice(None),
    (
        vbt.RelRange(length=0.4, length_space="all"),
        vbt.RelRange(length=0.4, length_space="all"),
        vbt.RelRange()
    ),
    index=data.index
)

# %%
# The second relative range is given offset=1.
vbt.Splitter.split_range(
    slice(None),
    (vbt.RelRange(length=0.75), vbt.RelRange(offset=1)),
    index=data.index
)

# %%
# A middle relative range of length=1 flagged with `is_gap=True` sits between the two ranges.
vbt.Splitter.split_range(
    slice(None),
    (
        vbt.RelRange(length=0.75),
        vbt.RelRange(length=1, is_gap=True),
        vbt.RelRange()
    ),
    index=data.index
)

# %%
# Ranges given as NumPy integer arrays.
vbt.Splitter.split_range(
    slice(None),
    (np.array([3, 4, 5]), np.array([6, 8, 10])),
    index=data.index
)

# %%
# Same arrays, additionally passing range_format="indices".
vbt.Splitter.split_range(
    slice(None),
    (np.array([3, 4, 5]), np.array([6, 8, 10])),
    range_format="indices",
    index=data.index
)

# %%
# Ranges given as slices of date strings.
vbt.Splitter.split_range(
    slice(None),
    (slice("2020", "2021"), slice("2021", "2022")),
    index=data.index
)

# %%
# Index labels at hard-coded positions 867 to 1233 (end exclusive).
data.index[867:1233]

# %%
# Index labels at hard-coded positions 1233 to 1598 (end exclusive).
data.index[1233:1598]

# %%
# Relative ranges whose lengths and offset are given as duration strings.
vbt.Splitter.split_range(
    slice(None),
    (
        vbt.RelRange(length="180 days"),
        vbt.RelRange(offset="1 day", length="90 days")
    ),
    index=data.index
)

# %% [markdown]
# ### Method

# %%
# Build a splitter from three hand-written splits. Each split pairs a default relative
# range with one of length=0.25 (length_space="all") at offset 0.5, 0.25 and 0.
# `backwards=True` is forwarded via `split_range_kwargs`; the sets are labeled "IS"/"OOS".
manual_splitter = vbt.Splitter.from_splits(
    data.index,
    [
        (vbt.RelRange(), vbt.RelRange(offset=0.5, length=0.25, length_space="all")),
        (vbt.RelRange(), vbt.RelRange(offset=0.25, length=0.25, length_space="all")),
        (vbt.RelRange(), vbt.RelRange(offset=0, length=0.25, length_space="all")),
    ],
    split_range_kwargs=dict(backwards=True),
    set_labels=["IS", "OOS"]
)
manual_splitter.splits

# %%
# Plot the manually built splitter.
manual_splitter.plot().show()

# %% [markdown]
# ## Generation
# ### Rolling

# %%
# `from_rolling` with length=360 only.
vbt.Splitter.from_rolling(
    data.index,
    length=360,
).plot().show()

# %%
# Same, with offset=90.
vbt.Splitter.from_rolling(
    data.index,
    length=360,
    offset=90
).plot().show()

# %%
# Same, with a negative fractional offset (-0.5).
vbt.Splitter.from_rolling(
    data.index,
    length=360,
    offset=-0.5
).plot().show()

# %%
# length=360 combined with split=0.5.
vbt.Splitter.from_rolling(
    data.index,
    length=360,
    split=0.5
).plot().show()

# %%
# Same, with `offset_anchor_set=None`.
vbt.Splitter.from_rolling(
    data.index,
    length=360,
    split=0.5,
    offset_anchor_set=None
).plot().show()

# %%
# `from_n_rolling` with n=5 and split=0.5; no length is given.
vbt.Splitter.from_n_rolling(
    data.index,
    n=5,
    split=0.5
).plot().show()

# %%
# `from_n_rolling` with n=3 and length=360.
vbt.Splitter.from_n_rolling(
    data.index,
    n=3,
    length=360,
    split=0.5
).plot().show()

# %%
# Same with n=7.
vbt.Splitter.from_n_rolling(
    data.index,
    n=7,
    length=360,
    split=0.5
).plot().show()

# %%
# `from_expanding` with min_length=360, offset=180 and a negative integer split (-180).
vbt.Splitter.from_expanding(
    data.index,
    min_length=360,
    offset=180,
    split=-180
).plot().show()

# %%
# `from_n_expanding` with n=5, min_length=360 and split=-180.
vbt.Splitter.from_n_expanding(
    data.index,
    n=5,
    min_length=360,
    split=-180
).plot().show()

# %% [markdown]
# ### Anchored

# %%
# `from_ranges` with every="AS" and split=0.5.
vbt.Splitter.from_ranges(
    data.index,
    every="AS",
    split=0.5
).plot().show()

# %%
# every="Q" with closed_end=True and lookback_period="AS".
vbt.Splitter.from_ranges(
    data.index,
    every="Q",
    closed_end=True,
    lookback_period="AS",
    split=0.5
).plot().show()

# %%
# Same ranges, but the split is given as two `vbt.RepEval` expressions over `index`:
# the first is true where the month differs from the last element's month, the second
# where it equals it.
vbt.Splitter.from_ranges(
    data.index,
    every="Q",
    closed_end=True,
    lookback_period="AS",
    split=(
        vbt.RepEval("index.month != index.month[-1]"),
        vbt.RepEval("index.month == index.month[-1]")
    )
).plot().show()

# %%
def qyear(index):
    """Return `index` converted to quarterly periods via ``to_period("Q")``."""
    quarterly = index.to_period("Q")
    return quarterly

# `from_ranges` with start=0, fixed_start=True, every="Q" and closed_end=True.
# The split is given as two callables over `index`: the first is true where the
# quarter (via `qyear`) differs from the last element's quarter, the second where it equals it.
vbt.Splitter.from_ranges(
    data.index,
    start=0,
    fixed_start=True,
    every="Q",
    closed_end=True,
    split=(
        lambda index: qyear(index) != qyear(index)[-1],
        lambda index: qyear(index) == qyear(index)[-1]
    )
).plot().show()

# %%
# `from_grouper` with by="AS" and split=0.5.
vbt.Splitter.from_grouper(
    data.index,
    by="AS",
    split=0.5
).plot().show()

# %%
def is_split_complete(index, split):
    """Check that a split starts on a year start and ends on a year end.

    Args:
        index: Index that the ranges in `split` select from.
        split: Sequence of ranges; only the first and the last one are inspected.

    Returns:
        Truthy if the first element selected by the first range is a year start
        and the last element selected by the last range is a year end.
    """
    head_range, tail_range = split[0], split[-1]
    starts_at = index[head_range][0]
    ends_at = index[tail_range][-1]
    return starts_at.is_year_start and ends_at.is_year_end

# Same grouper call, additionally passing `is_split_complete` wrapped in `vbt.RepFunc`
# as `split_check_template`.
vbt.Splitter.from_grouper(
    data.index,
    by="AS",
    split=0.5,
    split_check_template=vbt.RepFunc(is_split_complete)
).plot().show()

# %%
def format_split_labels(index, splits_arr):
    """Label every split with the year of its first timestamp.

    Args:
        index: Index that the ranges in `splits_arr` select from.
        splits_arr: Iterable of splits, each a sequence of ranges; only the
            first range of each split is used.

    Returns:
        `pd.Index` named "split_year" with one year per split.
    """
    first_years = [index[ranges[0]][0].year for ranges in splits_arr]
    return pd.Index(first_years, name="split_year")

# Same grouper call, additionally passing `format_split_labels` wrapped in `vbt.RepFunc`
# as `split_labels`.
vbt.Splitter.from_grouper(
    data.index,
    by="AS",
    split=0.5,
    split_check_template=vbt.RepFunc(is_split_complete),
    split_labels=vbt.RepFunc(format_split_labels)
).plot().show()

# %%
# Group by the index's `year` values instead of the "AS" string.
vbt.Splitter.from_grouper(
    data.index,
    by=data.index.year,
    split=0.5,
    split_check_template=vbt.RepFunc(is_split_complete)
).plot().show()

# %% [markdown]
# ### Random

# %%
# `from_n_random` with n=50, min_length=360, a fixed seed (42) and split=0.5.
vbt.Splitter.from_n_random(
    data.index,
    n=50,
    min_length=360,
    seed=42,
    split=0.5
).plot().show()

# %%
# Same, with min_length=30 and max_length=300.
vbt.Splitter.from_n_random(
    data.index,
    n=50,
    min_length=30,
    max_length=300,
    seed=42,
    split=0.5
).plot().show()

# %%
def start_p_func(i, indices):
    """Return `indices` divided by their sum.

    Args:
        i: Not used by this function.
        indices: Array-like supporting `.sum()` and element-wise division.

    Returns:
        `indices` scaled so that its elements add up to 1.
    """
    total = indices.sum()
    return indices / total

# Same random call as the previous cell, additionally passing `start_p_func`.
vbt.Splitter.from_n_random(
    data.index,
    n=50,
    min_length=30,
    max_length=300,
    seed=42,
    start_p_func=start_p_func,
    split=0.5
).plot().show()

# %% [markdown]
# ### Scikit-learn

# %%
# Build a splitter from a scikit-learn `KFold` instance with 5 splits and plot it.
from sklearn.model_selection import KFold

vbt.Splitter.from_sklearn(
    data.index,
    splitter=KFold(n_splits=5)
).plot().show()

# %% [markdown]
# ### Dynamic

# %%
def split_func(index, prev_start):
    """Produce the next one-year split, starting at the next month begin.

    Args:
        index: Datetime index with a set `freq`.
        prev_start: Start used as the anchor for this call, or None to anchor
            at the first index element.

    Returns:
        A list of two label slices (the first 9 month-begin steps, then the
        rest up to one year after the start), or None once the year would end
        later than one `index.freq` step past the last index element.
    """
    anchor = index[0] if prev_start is None else prev_start
    new_start = anchor + pd.offsets.MonthBegin(1)
    new_end = new_start + pd.DateOffset(years=1)
    index_limit = index[-1] + index.freq
    if new_end > index_limit:
        return None
    boundary = new_start + pd.offsets.MonthBegin(9)
    return [slice(new_start, boundary), slice(boundary, new_end)]

# Generate splits with `split_func`. The two `vbt.Rep` placeholders line up with the
# function's `index` and `prev_start` parameters — presumably substituted by the
# splitter on each call; verify against the vectorbt documentation.
vbt.Splitter.from_split_func(
    data.index,
    split_func=split_func,
    split_args=(vbt.Rep("index"), vbt.Rep("prev_start")),
    range_bounds_kwargs=dict(index_bounds=True)
).plot().show()

# %%
def get_next_monday(from_date):
    """Return midnight of the Monday on which the next business week can start.

    If `from_date` falls on a Monday no later than 09:00, midnight of that same
    day is returned; otherwise midnight of the first Monday after that day.

    Args:
        from_date: `pd.Timestamp` to search from.

    Returns:
        `pd.Timestamp` at 00:00 of the selected Monday.
    """
    day_start = from_date.floor("D")
    # `weekday` is a method: comparing the bound method itself with 0 is always False
    is_monday = from_date.weekday() == 0
    # Compare the elapsed time of day directly: ceiling to the hour would wrap
    # 23:xx over to hour 0 of the next day and wrongly pass an `hour <= 9` test
    if is_monday and from_date - day_start <= pd.Timedelta(hours=9):
        return day_start
    # n=1 moves to the first Monday after `day_start`; n=0 would leave a Monday
    # in place and return a date that is already in the past
    return day_start + pd.offsets.Week(n=1, weekday=0)

def get_next_business_range(from_date):
    """Return the label slice covering the next business week.

    The slice starts 9 hours after the timestamp returned by
    `get_next_monday(from_date)` and stops 4 days and 8 hours later, i.e.
    Monday 09:00 to Friday 17:00 when that timestamp is a Monday at midnight.
    """
    week_open = get_next_monday(from_date) + pd.DateOffset(hours=9)
    week_close = week_open + pd.DateOffset(days=4, hours=8)
    return slice(week_open, week_close)

def split_func(index, bounds):
    """Produce the next pair of consecutive business-week ranges.

    Args:
        index: Datetime index with a set `freq`.
        bounds: Sized sequence of previously generated bounds; empty on the
            first call.

    Returns:
        A `(train_range, test_range)` tuple of slices, or None once the test
        range would stop later than one `index.freq` step past the last index
        element.
    """
    # Start from the first index element when nothing has been generated yet,
    # otherwise from element [1][0] of the most recent entry in `bounds`
    from_date = bounds[-1][1][0] if len(bounds) > 0 else index[0]
    train_week = get_next_business_range(from_date)
    # The test week is searched from where the train week stops
    test_week = get_next_business_range(train_week.stop)
    index_limit = index[-1] + index.freq
    if test_week.stop > index_limit:
        return None
    return train_week, test_week

# Generate splits with the business-week `split_func` over a 15-minute index from
# 2020-01 to 2020-03; the `vbt.Rep` placeholders line up with its `index` and `bounds` parameters.
vbt.Splitter.from_split_func(
    pd.date_range("2020-01", "2020-03", freq="15min"),
    split_func=split_func,
    split_args=(vbt.Rep("index"), vbt.Rep("bounds")),
    range_bounds_kwargs=dict(index_bounds=True)
).plot().show()

# %% [markdown]
# ## Validation

# %%
# Rebind the notebook-level `splitter` to one built with `from_ranges`
# (every="AS", closed_end=True, split=0.5, sets labeled "IS"/"OOS") and plot it.
# The cells below operate on this splitter.
splitter = vbt.Splitter.from_ranges(
    data.index,
    every="AS",
    closed_end=True,
    split=0.5,
    set_labels=["IS", "OOS"]
)
splitter.plot().show()

# %% [markdown]
# ### Bounds

# %%
# Get the bounds array and display its shape.
bounds_arr = splitter.get_bounds_arr()
bounds_arr.shape

# %%
# Display the bounds array itself.
bounds_arr

# %%
# Get the bounds with `index_bounds=True` and display their shape.
bounds = splitter.get_bounds(index_bounds=True)
bounds.shape

# %%
# Display the bounds object itself.
bounds

# %%
# "end" bound of split 0, set "OOS".
bounds.loc[(0, "OOS"), "end"]

# %%
# "start" bound of split 1, set "IS".
bounds.loc[(1, "IS"), "start"]

# %% [markdown]
# ### Masks

# %%
# Get the splitter's mask and display its shape.
mask = splitter.get_mask()
mask.shape

# %%
# Display the mask itself.
mask

# %%
# Restrict the mask to labels within 2021 and reduce with `any()`.
mask["2021":"2021"].any()

# %%
# Resample the mask with the "AS" rule and sum each bucket.
mask.resample("AS").sum()

# %%
# Resample each split's mask with the "AS" rule and sum it, then place the
# per-split results side by side, keyed by the splitter's split labels.
# A comprehension keeps the loop variable local, so the notebook-level `mask`
# assigned in an earlier cell is not overwritten by the last split's mask.
results = [
    split_mask.resample("AS").sum()
    for split_mask in splitter.get_iter_split_masks()
]
pd.concat(results, axis=1, keys=splitter.split_labels)

# %% [markdown]
# ### Coverage

# %%
# Coverage reported per split.
splitter.get_split_coverage()

# %%
# Coverage reported per set.
splitter.get_set_coverage()

# %%
# Coverage reported per range.
splitter.get_range_coverage()

# %%
# Overall coverage.
splitter.get_coverage()

# %%
# Check whether the "start" index bound of split 2, set "OOS" lies in a leap year.
splitter.index_bounds.loc[(2, "OOS"), "start"].is_leap_year

# %%
# Range coverage with `relative=True`.
splitter.get_range_coverage(relative=True)

# %%
# Set coverage with `relative=True`.
splitter.get_set_coverage(relative=True)

# %%
# Split coverage with `overlapping=True`.
splitter.get_split_coverage(overlapping=True)

# %%
# Set coverage with `overlapping=True`.
splitter.get_set_coverage(overlapping=True)

# %%
# Overall coverage with `overlapping=True`.
splitter.get_coverage(overlapping=True)

# %%
# Plot the coverage.
splitter.plot_coverage().show()

# %%
# Overlap matrix computed by range, with `normalize=False`.
splitter.get_overlap_matrix(by="range", normalize=False)

# %% [markdown]
# ### Grouping

# %%
# Index bounds with `set_group_by=True`.
splitter.get_bounds(index_bounds=True, set_group_by=True)

# %% [markdown]
# ## Manipulation

# %%
# Rebind `splitter` to one grouped by the index's year values, renamed to "split_year".
splitter = vbt.Splitter.from_grouper(
    data.index,
    by=data.index.year.rename("split_year")
)

# %%
# Display the splitter's stats.
splitter.stats()

# %%
# Display the splitter's plots.
splitter.plots().show()

# %%
# Drop the first and the last entry via positional slicing, then display the stats again.
splitter = splitter.iloc[1:-1]
splitter.stats()

# %%
def new_split(index):
    """Split an index into three boolean masks by calendar quarter.

    Args:
        index: Datetime index exposing a `quarter` attribute.

    Returns:
        A list of three masks: quarters 1-2, quarter 3, and quarter 4.
    """
    quarter = index.quarter
    first_half = np.isin(quarter, [1, 2])
    third_quarter = quarter == 3
    fourth_quarter = quarter == 4
    return [first_half, third_quarter, fourth_quarter]

# Apply `new_split` (wrapped in `vbt.RepFunc`) through `split_set` and label the
# resulting sets "train", "valid" and "test"; `splitter` is rebound to the result.
splitter = splitter.split_set(
    vbt.RepFunc(new_split),
    new_set_labels=["train", "valid", "test"]
)

# %%
# Display the stats of the re-split splitter.
splitter.stats()

# %%
# Display the plots of the re-split splitter.
splitter.plots().show()

# %%