# ################################## HOW TO USE #################################### #
#                                                                                    #
# This is a Jupyter notebook formatted as a script                                   #
# Format: https://jupytext.readthedocs.io/en/latest/formats.html#the-percent-format  #
#                                                                                    #
# Save this file and remove the '.txt' extension                                     #
# In Jupyter Lab, right click on the Python file -> Open With -> Jupytext Notebook   #
# Make sure to have Jupytext installed: https://github.com/mwouts/jupytext           #
#                                                                                    #
# ################################################################################## #

# %% [markdown]
# #  Dynamic

# %%
# Per-group state that `pre_group_func_nb` creates and the segment/order
# callbacks below read and update. Every field holds one slot per group column.
GroupMemory = namedtuple(
    "GroupMemory",
    ("target_alloc", "size_type", "direction", "order_value_out"),
)

@njit
def pre_group_func_nb(c):
    """Create the per-group memory before a group is processed.

    Every array gets ``c.group_len`` slots. Target allocations and order
    values start out as NaN, size types as ``SizeType.TargetPercent`` and
    directions as ``Direction.Both``.

    Returns:
        A one-element tuple containing the new `GroupMemory`.
    """
    n_cols = c.group_len
    pct_size_type = vbt.pf_enums.SizeType.TargetPercent
    both_directions = vbt.pf_enums.Direction.Both
    memory = GroupMemory(
        target_alloc=np.full(n_cols, np.nan),
        size_type=np.full(n_cols, pct_size_type),
        direction=np.full(n_cols, both_directions),
        order_value_out=np.full(n_cols, np.nan),
    )
    return (memory,)

@njit
def pre_segment_func_nb(
    c,
    group_memory,
    min_history,
    threshold,
    allocate_func_nb,
    *args
):
    """Decide whether the current segment of the group must be rebalanced.

    Nothing happens before row ``min_history``. From then on a rebalance is
    triggered when no column of the group holds a position, or when any
    column's current allocation (position * valuation price / group value)
    deviates from its stored target by at least that column's threshold.

    On rebalance, ``allocate_func_nb(c, group_memory, *args)`` is called to
    refresh the targets and the call sequence is re-sorted from the memory
    arrays.

    Returns:
        ``(group_memory, should_rebalance)``.
    """
    should_rebalance = False

    if c.i >= min_history:
        # Does any column of this group currently hold a position?
        has_position = False
        for col in range(c.from_col, c.to_col):
            if c.last_position[col] != 0:
                has_position = True
                break

        if has_position:
            group_value = c.last_value[c.group]
            for k in range(c.group_len):
                col = c.from_col + k
                actual_alloc = (
                    c.last_position[col] * c.last_val_price[col] / group_value
                )
                max_drift = vbt.pf_nb.select_from_col_nb(c, col, threshold)
                drift = actual_alloc - group_memory.target_alloc[k]

                # One column drifting too far is enough to rebalance the group.
                if abs(drift) >= max_drift:
                    should_rebalance = True
                    break
        else:
            # Nothing is held yet, so allocate right away.
            should_rebalance = True

    if should_rebalance:
        allocate_func_nb(c, group_memory, *args)
        vbt.pf_nb.sort_call_seq_1d_nb(
            c,
            group_memory.target_alloc,
            group_memory.size_type,
            group_memory.direction,
            group_memory.order_value_out,
        )

    return group_memory, should_rebalance

@njit
def order_func_nb(
    c,
    group_memory,
    should_rebalance,
    price,
    fees
):
    """Build the order for the current column.

    When ``should_rebalance`` is false no order is issued. Otherwise the order
    size is the column's target allocation, with the size type and direction
    stored in ``group_memory`` and the price/fees selected for the current
    element.
    """
    if should_rebalance:
        # Position of this column inside its group.
        k = c.col - c.from_col
        return vbt.pf_nb.order_nb(
            size=group_memory.target_alloc[k],
            price=vbt.pf_nb.select_nb(c, price),
            size_type=group_memory.size_type[k],
            direction=group_memory.direction[k],
            fees=vbt.pf_nb.select_nb(c, fees),
        )

    return vbt.pf_nb.order_nothing_nb()

# %%
@njit
def uniform_allocate_func_nb(c, group_memory):
    """Give every column of the group the same target weight, ``1 / c.group_len``."""
    n_assets = c.group_len
    targets = group_memory.target_alloc
    for k in range(n_assets):
        targets[k] = 1 / n_assets

# %%
def simulate_threshold_rebalancing(threshold, allocate_func_nb, *args, **kwargs):
    """Simulate threshold-based rebalancing with `vbt.Portfolio.from_order_func`.

    Args:
        threshold: Allocation deviation that triggers a rebalance; broadcast
            as the named argument ``threshold``.
        allocate_func_nb: Callback that writes new target allocations; it is
            forwarded to `pre_segment_func_nb`.
        *args: Extra positional arguments appended after ``allocate_func_nb``
            in the pre-segment arguments.
        **kwargs: Additional keyword arguments for ``from_order_func``.

    Uses the module-level ``data`` for close/open prices, a fee of 0.005,
    cash sharing, and a leading pre-segment argument of 0.

    Returns:
        Whatever ``vbt.Portfolio.from_order_func`` returns.
    """
    segment_args = (0, vbt.Rep("threshold"), allocate_func_nb) + args
    named_args = {
        "price": data.get("Close"),
        "fees": 0.005,
        "threshold": threshold,
    }
    return vbt.Portfolio.from_order_func(
        data.get("Close"),
        open=data.get("Open"),
        pre_group_func_nb=pre_group_func_nb,
        pre_group_args=(),
        pre_segment_func_nb=pre_segment_func_nb,
        pre_segment_args=segment_args,
        order_func_nb=order_func_nb,
        order_args=(vbt.Rep("price"), vbt.Rep("fees")),
        broadcast_named_args=named_args,
        cash_sharing=True,
        group_by=vbt.ExceptLevel("symbol"),
        freq="1h",
        **kwargs
    )

# Rebalance to equal weights whenever an allocation deviates by 0.05 or more,
# then plot the resulting allocations.
pf = simulate_threshold_rebalancing(0.05, uniform_allocate_func_nb)
pf.plot_allocations().show()

# %%
import os

# Numba debugging switches.
# NOTE(review): these presumably only take effect if set before Numba (and
# anything importing it) is first imported, i.e. in a fresh kernel before the
# other cells run. Confirm against the Numba environment-variable docs.
os.environ["NUMBA_BOUNDSCHECK"] = "1"
os.environ["NUMBA_DISABLE_JIT"] = "1"

# %%
# Run the same simulation for thresholds 0.01, 0.02, ..., 0.15 in one call.
pf = simulate_threshold_rebalancing(
    vbt.Param(np.arange(1, 16) / 100, name="threshold"),
    uniform_allocate_func_nb
)

# Cell output: Sharpe ratio of the parameterized portfolio.
pf.sharpe_ratio

# %% [markdown]
# ## Post-analysis

# %%
@njit
def track_uniform_allocate_func_nb(c, group_memory, index_points, alloc_counter):
    """Allocate equal weights and record the row at which this happened.

    The current row ``c.i`` is written into ``index_points`` at the position
    given by ``alloc_counter[0]``, which is then incremented.
    """
    n_assets = c.group_len
    targets = group_memory.target_alloc
    for k in range(n_assets):
        targets[k] = 1 / n_assets

    slot = alloc_counter[0]
    index_points[slot] = c.i
    alloc_counter[0] = slot + 1

# One slot per row of the data; the callback fills these from the front.
index_points = np.empty(data.wrapper.shape[0], dtype=np.int_)
# Single-element counter of how many slots have been written.
alloc_counter = np.full(1, 0)
pf = simulate_threshold_rebalancing(
    0.05,
    track_uniform_allocate_func_nb,
    index_points,
    alloc_counter
)
# Keep only the slots that were actually written.
index_points = index_points[:alloc_counter[0]]

# Cell output: index labels of the rows where an allocation took place.
data.wrapper.index[index_points]

# %%
@njit
def random_allocate_func_nb(
    c,
    group_memory,
    alloc_points,
    alloc_weights,
    alloc_counter
):
    """Draw random target weights and log the allocation in caller-owned buffers.

    Weights are sampled uniformly from [0, 1) and normalised to sum to one.
    The record is stored at the next free row, which is the sum of
    ``alloc_counter`` over all groups. The per-group counter supplies the
    record id and is incremented afterwards.
    """
    raw = np.random.uniform(0, 1, c.group_len)
    group_memory.target_alloc[:] = raw / raw.sum()

    seq_in_group = alloc_counter[c.group]
    # Rows are shared by all groups, so the write position is the global count.
    slot = alloc_counter.sum()
    alloc_points["id"][slot] = seq_in_group
    alloc_points["col"][slot] = c.group
    alloc_points["alloc_idx"][slot] = c.i
    alloc_weights[slot] = group_memory.target_alloc
    alloc_counter[c.group] = seq_in_group + 1

# Pre-allocate output buffers sized for the worst case of one allocation per
# row for every threshold.
thresholds = pd.Index(np.arange(1, 16) / 100, name="threshold")
max_entries = data.wrapper.shape[0] * len(thresholds)
alloc_points = np.empty(max_entries, dtype=vbt.pf_enums.alloc_point_dt)
# `np.float64` instead of the alias `np.float_`, which NumPy 2.0 removed.
alloc_weights = np.empty((max_entries, len(data.symbols)), dtype=np.float64)
# One counter per threshold.
alloc_counter = np.full(len(thresholds), 0)

pf = simulate_threshold_rebalancing(
    vbt.Param(thresholds),
    random_allocate_func_nb,
    alloc_points,
    alloc_weights,
    alloc_counter,
    seed=42
)
# Drop the unused tail of both buffers.
alloc_points = alloc_points[:alloc_counter.sum()]
alloc_weights = alloc_weights[:alloc_counter.sum()]

# %%
@njit
def random_allocate_func_nb(c, group_memory):
    """Draw random target weights and log the allocation in ``c.in_outputs``.

    Same bookkeeping as the buffer-based variant, except that the
    ``alloc_points``, ``alloc_weights`` and ``alloc_counter`` arrays are read
    from the context's in-outputs instead of being passed as arguments.
    """
    raw = np.random.uniform(0, 1, c.group_len)
    group_memory.target_alloc[:] = raw / raw.sum()

    out = c.in_outputs
    seq_in_group = out.alloc_counter[c.group]
    # Next free row is the total number of records written so far.
    slot = out.alloc_counter.sum()
    out.alloc_points["id"][slot] = seq_in_group
    out.alloc_points["col"][slot] = c.group
    out.alloc_points["alloc_idx"][slot] = c.i
    out.alloc_weights[slot] = group_memory.target_alloc
    out.alloc_counter[c.group] = seq_in_group + 1

# Templates that build the output buffers from `target_shape` and
# `group_lens`, so their sizes do not have to be hard-coded here.
alloc_points = vbt.RepEval("""
    max_entries = target_shape[0] * len(group_lens)
    np.empty(max_entries, dtype=alloc_point_dt)
""", context=dict(alloc_point_dt=vbt.pf_enums.alloc_point_dt))
# `np.float64` instead of the alias `np.float_`, which NumPy 2.0 removed.
alloc_weights = vbt.RepEval("""
    max_entries = target_shape[0] * len(group_lens)
    np.empty((max_entries, n_cols), dtype=np.float64)
""", context=dict(n_cols=len(data.symbols)))
alloc_counter = vbt.RepEval("np.full(len(group_lens), 0)")

# Named container whose fields the callback reads through `c.in_outputs`.
InOutputs = namedtuple("InOutputs", [
    "alloc_points",
    "alloc_weights",
    "alloc_counter"
])
in_outputs = InOutputs(
    alloc_points=alloc_points,
    alloc_weights=alloc_weights,
    alloc_counter=alloc_counter,
)

pf = simulate_threshold_rebalancing(
    vbt.Param(np.arange(1, 16) / 100, name="threshold"),
    random_allocate_func_nb,
    in_outputs=in_outputs,
    seed=42
)
# Keep only the records that were actually written.
alloc_points = pf.in_outputs.alloc_points[:pf.in_outputs.alloc_counter.sum()]
alloc_weights = pf.in_outputs.alloc_weights[:pf.in_outputs.alloc_counter.sum()]

# %%
# Wrap the recorded allocation points and weights in a portfolio optimizer
# that shares the simulated portfolio's wrapper.
pfo = vbt.PortfolioOptimizer(
    wrapper=pf.wrapper,
    alloc_records=vbt.AllocPoints(
        pf.wrapper.resolve(),
        alloc_points
    ),
    allocations=alloc_weights
)

# %%
# Summary statistics of the allocations recorded for threshold 0.1.
pfo[0.1].allocations.describe()

# %%
pfo.plot(column=0.1).show()

# %%
pfo.plot(column=0.03).show()

# %%
# Allocations of the simulated portfolio for threshold 0.03.
pf[0.03].plot_allocations().show()

# %%
pf.sharpe_ratio

# %%
# Re-simulate from the optimizer, using the open price as valuation price and
# the same frequency and fees as above.
pf_new = vbt.Portfolio.from_optimizer(
    data,
    pfo,
    val_price=data.get("Open"),
    freq="1h",
    fees=0.005
)

pf_new.sharpe_ratio

# %% [markdown]
# ## Bonus 1: Own optimizer

# %%
@njit(nogil=True)
def optimize_portfolio_nb(
    close,
    val_price,
    range_starts,
    range_ends,
    optimize_func_nb,
    optimize_args=(),
    price=np.inf,
    fees=0.,
    init_cash=100.,
    group=0
):
    """Rebalance a single cash-sharing group at the end of each index range.

    For every range ``k`` the weights returned by
    ``optimize_func_nb(range_starts[k], range_ends[k], *optimize_args)`` are
    executed as ``TargetPercent`` orders in both directions at row
    ``range_ends[k]``. Orders are processed in the order produced by sorting
    their approximated order values.

    Args:
        close: Two-dimensional array of close prices (rows x columns).
        val_price: Valuation price, scalar or array, selected flexibly.
        range_starts: Start index of each optimization range.
        range_ends: End index of each range; also used as the row at which
            orders are placed, so each value must be a valid row of ``close``.
        optimize_func_nb: Callback returning one weight per column.
        optimize_args: Extra arguments passed to ``optimize_func_nb``.
        price: Order price, scalar or array, selected flexibly.
        fees: Order fees, scalar or array, selected flexibly.
        init_cash: Starting cash of the group.
        group: Group index passed to ``process_order_nb``.

    Returns:
        The result of ``vbt.nb.repartition_nb(order_records, order_counts)``.
    """
    val_price_ = vbt.to_2d_array_nb(np.asarray(val_price))
    price_ = vbt.to_2d_array_nb(np.asarray(price))
    fees_ = vbt.to_2d_array_nb(np.asarray(fees))

    # At most one order per element of `close`.
    order_records = np.empty(close.shape, dtype=vbt.pf_enums.order_dt)
    order_counts = np.full(close.shape[1], 0, dtype=np.int_)

    # `np.float64` instead of the alias `np.float_`, which NumPy 2.0 removed.
    order_value = np.empty(close.shape[1], dtype=np.float64)
    call_seq = np.empty(close.shape[1], dtype=np.int_)

    # Running per-column and group-wide state.
    last_position = np.full(close.shape[1], 0.0, dtype=np.float64)
    last_debt = np.full(close.shape[1], 0.0, dtype=np.float64)
    last_locked_cash = np.full(close.shape[1], 0.0, dtype=np.float64)
    cash_now = float(init_cash)
    free_cash_now = float(init_cash)
    value_now = float(init_cash)

    for k in range(len(range_starts)):
        i = range_ends[k]
        size = optimize_func_nb(
            range_starts[k],
            range_ends[k],
            *optimize_args
        )

        # Re-value the group: cash plus every position at its valuation price.
        value_now = cash_now
        for col in range(close.shape[1]):
            val_price_now = vbt.flex_select_nb(val_price_, i, col)
            value_now += last_position[col] * val_price_now

        # Approximate each column's order value to derive the call sequence.
        for col in range(close.shape[1]):
            val_price_now = vbt.flex_select_nb(val_price_, i, col)
            exec_state = vbt.pf_enums.ExecState(
                cash=cash_now,
                position=last_position[col],
                debt=last_debt[col],
                locked_cash=last_locked_cash[col],
                free_cash=free_cash_now,
                val_price=val_price_now,
                value=value_now,
            )
            order_value[col] = vbt.pf_nb.approx_order_value_nb(
                exec_state,
                size[col],
                vbt.pf_enums.SizeType.TargetPercent,
                vbt.pf_enums.Direction.Both,
            )
            call_seq[col] = col

        vbt.pf_nb.insert_argsort_nb(order_value, call_seq)

        # Execute the orders in the sorted sequence.
        for c in range(close.shape[1]):
            col = call_seq[c]

            order = vbt.pf_nb.order_nb(
                size=size[col],
                price=vbt.flex_select_nb(price_, i, col),
                size_type=vbt.pf_enums.SizeType.TargetPercent,
                direction=vbt.pf_enums.Direction.Both,
                fees=vbt.flex_select_nb(fees_, i, col),
            )

            # Only the close is known here; the other prices are left as NaN.
            price_area = vbt.pf_enums.PriceArea(
                open=np.nan,
                high=np.nan,
                low=np.nan,
                close=vbt.flex_select_nb(close, i, col),
            )
            val_price_now = vbt.flex_select_nb(val_price_, i, col)
            exec_state = vbt.pf_enums.ExecState(
                cash=cash_now,
                position=last_position[col],
                debt=last_debt[col],
                locked_cash=last_locked_cash[col],
                free_cash=free_cash_now,
                val_price=val_price_now,
                value=value_now,
            )
            _, new_exec_state = vbt.pf_nb.process_order_nb(
                group=group,
                col=col,
                i=i,
                exec_state=exec_state,
                order=order,
                price_area=price_area,
                order_records=order_records,
                order_counts=order_counts
            )

            # Carry the post-order state over to the next order.
            cash_now = new_exec_state.cash
            free_cash_now = new_exec_state.free_cash
            value_now = new_exec_state.value
            last_position[col] = new_exec_state.position
            last_debt[col] = new_exec_state.debt
            last_locked_cash[col] = new_exec_state.locked_cash

    return vbt.nb.repartition_nb(order_records, order_counts)

# %%
@njit(nogil=True)
def sharpe_optimize_func_nb(
    start_idx,
    end_idx,
    close,
    num_tests,
    ann_factor
):
    """Pick, by random search, the weights with the highest annualized Sharpe ratio.

    Simple returns are computed from ``close[start_idx:end_idx]``. Then
    ``num_tests`` random weight vectors, each normalised to sum to one, are
    scored by annualized mean return divided by annualized volatility.

    Args:
        start_idx: First row of the look-back window (inclusive).
        end_idx: End row of the look-back window (exclusive).
        close: Two-dimensional array of close prices (rows x columns).
        num_tests: Number of random candidates to evaluate.
        ann_factor: Annualization factor applied to mean and variance.

    Returns:
        One weight per column. All NaN if no candidate scored above ``-inf``,
        for example when ``num_tests`` is 0.
    """
    close_period = close[start_idx:end_idx]
    returns = (close_period[1:] - close_period[:-1]) / close_period[:-1]
    mean = vbt.nb.nanmean_nb(returns)
    cov = np.cov(returns, rowvar=False)
    best_sharpe_ratio = -np.inf
    # `np.float64` instead of the alias `np.float_`, which NumPy 2.0 removed.
    weights = np.full(close.shape[1], np.nan, dtype=np.float64)

    for i in range(num_tests):
        w = np.random.random_sample(close.shape[1])
        w = w / np.sum(w)
        p_return = np.sum(mean * w) * ann_factor
        p_std = np.sqrt(np.dot(w.T, np.dot(cov, w))) * np.sqrt(ann_factor)
        sharpe_ratio = p_return / p_std
        if sharpe_ratio > best_sharpe_ratio:
            best_sharpe_ratio = sharpe_ratio
            weights = w

    return weights

# %%
# Index ranges generated with every="W".
range_starts, range_ends = data.wrapper.get_index_ranges(every="W")
# Number of one-hour periods in 365 days.
ann_factor = pd.Timedelta("365d") / pd.Timedelta("1h")
init_cash = 100
num_tests = 30
fees = 0.005

order_records = optimize_portfolio_nb(
    data.get("Close").values,
    data.get("Open").values,
    range_starts,
    range_ends,
    sharpe_optimize_func_nb,
    optimize_args=(data.get("Close").values, num_tests, ann_factor),
    fees=fees,
    init_cash=init_cash
)

# %%
# Build a portfolio from the raw order records; no log records are supplied.
pf = vbt.Portfolio(
    wrapper=symbol_wrapper.regroup(True),
    close=data.get("Close"),
    order_records=order_records,
    log_records=np.array([]),
    cash_sharing=True,
    init_cash=init_cash
)

# %%
pf.plot_allocations().show()

# %% [markdown]
# ## Bonus 2: Parameterization

# %%
def merge_func(order_records_list, param_index):
    """Turn the order records of each parameter combination into a Sharpe ratio.

    Args:
        order_records_list: One array of order records per parameter
            combination, in the same order as ``param_index``.
        param_index: Index labelling the parameter combinations.

    Returns:
        A float Series indexed by ``param_index`` holding each combination's
        Sharpe ratio. Entries with no matching order records stay NaN.
    """
    # `np.float64` instead of the alias `np.float_`, which NumPy 2.0 removed.
    sharpe_ratios = pd.Series(index=param_index, dtype=np.float64)
    for i, order_records in enumerate(order_records_list):
        pf = vbt.Portfolio(
            wrapper=symbol_wrapper.regroup(True),
            close=data.get("Close"),
            order_records=order_records,
            cash_sharing=True,
            init_cash=init_cash
        )
        # `i` is a position, not a label of `param_index`, so write through
        # `.iloc` rather than relying on the deprecated positional fallback.
        sharpe_ratios.iloc[i] = pf.sharpe_ratio
    return sharpe_ratios

# %%
# Parameterized wrapper around `optimize_portfolio_nb`: results are combined
# by `merge_func`, which receives the parameter index via a template.
param_optimize_portfolio_nb = vbt.parameterized(
    optimize_portfolio_nb,
    merge_func=merge_func,
    merge_kwargs=dict(param_index=vbt.Rep("param_index")),
    engine="dask",
    chunk_len=4
)

# %%
every_index = pd.Index(["D", "W", "MS"], name="every")
num_tests_index = pd.Index([30, 50, 100], name="num_tests")
fees_index = pd.Index([0.0, 0.005, 0.01], name="fees")

# Pre-compute the index ranges for every rebalancing frequency.
range_starts = []
range_ends = []
for every in every_index:
    index_ranges = symbol_wrapper.get_index_ranges(every=every)
    range_starts.append(index_ranges[0])
    range_ends.append(index_ranges[1])
num_tests = num_tests_index.tolist()

# `range_starts` and `range_ends` share level 0 and the same keys
# (presumably so they vary together rather than being crossed; verify
# against the `vbt.Param` docs).
range_starts = vbt.Param(range_starts, level=0, keys=every_index)
range_ends = vbt.Param(range_ends, level=0, keys=every_index)
num_tests = vbt.Param(num_tests, level=1, keys=num_tests_index)
fees = vbt.Param(fees_index.values, level=2, keys=fees_index)

# %%
sharpe_ratios = param_optimize_portfolio_nb(
    data.get("Close").values,
    data.get("Open").values,
    range_starts,
    range_ends,
    sharpe_optimize_func_nb,
    optimize_args=(
        data.get("Close").values,
        num_tests,
        ann_factor
    ),
    fees=fees,
    init_cash=init_cash,
    group=vbt.Rep("config_idx")
)

# %%
sharpe_ratios

# %% [markdown]
# ## Bonus 3: Hyperopt

# %%
def objective(kwargs):
    """Hyperopt objective: negative Sharpe ratio for one parameter sample.

    Args:
        kwargs: Sampled parameters with the keys ``"every"`` (rebalancing
            frequency), ``"num_tests"`` (number of random weight candidates)
            and ``"fees"``.

    Returns:
        The negated Sharpe ratio of the resulting portfolio, so that
        minimising the objective maximises the Sharpe ratio.
    """
    close_values = data.get("Close").values
    open_values = data.get("Open").values
    index_ranges = symbol_wrapper.get_index_ranges(every=kwargs["every"])
    # `hp.quniform` yields floats (e.g. 52.0); the optimizer uses this value
    # as a loop count, so pass a real integer.
    num_tests = int(kwargs["num_tests"])
    order_records = optimize_portfolio_nb(
        close_values,
        open_values,
        index_ranges[0],
        index_ranges[1],
        sharpe_optimize_func_nb,
        optimize_args=(close_values, num_tests, ann_factor),
        fees=vbt.to_2d_array(kwargs["fees"]),
        init_cash=init_cash
    )
    pf = vbt.Portfolio(
        wrapper=symbol_wrapper.regroup(True),
        close=data.get("Close"),
        order_records=order_records,
        log_records=np.array([]),
        cash_sharing=True,
        init_cash=init_cash
    )
    return -pf.sharpe_ratio

# %%
from hyperopt import fmin, tpe, hp

# Search space: rebalancing frequency from "1D" to "99D", number of random
# candidates between 5 and 100 in steps of 1, and fees between 0 and 0.05.
space = {
    "every": hp.choice("every", ["%dD" % n for n in range(1, 100)]),
    "num_tests": hp.quniform("num_tests", 5, 100, 1),
    "fees": hp.uniform('fees', 0, 0.05)
}

# %%
# Minimise the objective with TPE over 30 evaluations.
best = fmin(
    fn=objective,
    space=space,
    algo=tpe.suggest,
    max_evals=30
)

# %%
best

# %% [markdown]
# ## Bonus 4: Hybrid

# %%
def optimize_func(
    data,
    index_slice,
    temp_allocations,
    temp_pfs,
    threshold
):
    """Return equal weights when the allocation drifted too far, else NaNs.

    On the first call (empty ``temp_allocations``) equal weights are always
    returned. On later calls the previously returned allocation is simulated
    over ``data.iloc[index_slice]``, continuing from the cash, assets and
    close of the last portfolio in ``temp_pfs``. The allocation at the end of
    that simulation is then compared with the most recent non-NaN target.

    Args:
        data: Data object supporting ``.iloc`` slicing.
        index_slice: Slice of the index covered by the current period.
        temp_allocations: List of allocations returned so far; the new
            allocation is appended to it.
        temp_pfs: List of per-period portfolios; the new one is appended.
        threshold: Minimum absolute allocation deviation that triggers a
            rebalance.

    Returns:
        Array with one weight per symbol: equal weights if rebalancing,
        otherwise all NaN.
    """
    sub_data = data.iloc[index_slice]
    if len(temp_allocations) > 0:
        # Simulate the current period with the allocation decided last time
        # (an all-NaN allocation is passed through unchanged).
        prev_allocation = sub_data.symbol_wrapper.wrap(
            [temp_allocations[-1]],
            index=sub_data.wrapper.index[[0]]
        )
        prev_pfo = vbt.PortfolioOptimizer.from_allocations(
            sub_data.symbol_wrapper,
            prev_allocation
        )
        if len(temp_pfs) > 0:
            init_cash = temp_pfs[-1].cash.iloc[-1]
            init_position = temp_pfs[-1].assets.iloc[-1]
            init_price = temp_pfs[-1].close.iloc[-1]
        else:
            init_cash = 100.
            init_position = 0.
            init_price = np.nan
        prev_pf = prev_pfo.simulate(
            sub_data,
            init_cash=init_cash,
            init_position=init_position,
            init_price=init_price
        )
        temp_pfs.append(prev_pf)
        curr_alloc = prev_pf.allocations.iloc[-1].values

        # Measure drift against the latest real target. The last entry is
        # all-NaN after a skipped period, and comparing with NaN is always
        # False, which would disable rebalancing for good after one skip.
        last_target = None
        for past_alloc in reversed(temp_allocations):
            if not np.isnan(past_alloc).all():
                last_target = past_alloc
                break

        if last_target is None:
            # No target was ever set, so there is nothing to hold on to.
            should_rebalance = True
        else:
            should_rebalance = bool(
                (np.abs(curr_alloc - last_target) >= threshold).any()
            )
    else:
        should_rebalance = True
    n_symbols = len(sub_data.symbols)
    if should_rebalance:
        new_alloc = np.full(n_symbols, 1 / n_symbols)
    else:
        new_alloc = np.full(n_symbols, np.nan)
    temp_allocations.append(new_alloc)
    return new_alloc

# Shared state that `optimize_func` appends to on every call.
pfs = []
allocations = []
# Call `optimize_func` with every="W" and a threshold of 0.03.
pfopt = vbt.PortfolioOptimizer.from_optimize_func(
    data.symbol_wrapper,
    optimize_func,
    data,
    vbt.Rep("index_slice"),
    allocations,
    pfs,
    0.03,
    every="W"
)
pf = pfopt.simulate(data)

# %%
# Last value of every per-period portfolio, stacked into one Series. `.iloc`
# selects the final row by position; plain `[[-1]]` on a non-integer index
# relies on a positional fallback that pandas has deprecated.
final_values = pd.concat(map(lambda x: x.value.iloc[[-1]], pfs))
final_values

# %%
# Sanity check: the per-period end values must equal the value of the full
# simulation at the same timestamps.
pd.testing.assert_series_equal(
    final_values,
    pf.value.loc[final_values.index],
)

# %%