jax support added/multiinput

This commit is contained in:
David Brazda
2023-12-26 18:25:25 +01:00
parent 17b9859a73
commit 77faa919c0
4 changed files with 85 additions and 48 deletions

View File

@ -8,7 +8,7 @@ import numpy as np
from collections import defaultdict
"""
Předpokladm, že buď používáme 1) bar+standard indikator 2) cbars indicators - zatim nepodporovano spolecne (jine time rozliseni)
"""
def model(state, params, ind_name):
funcName = "model"
@ -32,25 +32,28 @@ def model(state, params, ind_name):
try:
mdl = state.vars.loaded_models[name]
#Optimalizovano, aby se v kazde iteraci nemusel volat len
if state.cache.get(name, {}).get("skip_init", False) is False:
if mdl.use_cbars is False:
if len(state.bars["close"]) < mdl.input_sequences:
return 0, 0
else:
state.cache[name]["skip_init"] = True
state.cache[name]["indicators"] = state.indicators
state.cache[name]["bars"] = state.bars if mdl.use_bars else {}
#return -2, f"too soon - not enough data for seq {seq=}"
else:
if len(state.cbar_indicators["time"]) < mdl.input_sequences:
return 0, 0
else:
state.cache[name]["skip_init"] = True
state.cache[name]["indicators"] = state.cbar_indicators
state.cache[name]["bars"] = state.bars if mdl.use_bars else {}
# #Optimalizovano, aby se v kazde iteraci nemusel volat len
# if state.cache.get(name, {}).get("skip_init", False) is False:
# if mdl.use_cbars is False:
# if len(state.bars["close"]) < mdl.input_sequences:
# return 0, 0
# else:
# state.cache[name]["skip_init"] = True
# state.cache[name]["indicators"] = state.indicators
# state.cache[name]["bars"] = state.bars if mdl.use_bars else {}
# #return -2, f"too soon - not enough data for seq {seq=}"
# else:
# if len(state.cbar_indicators["time"]) < mdl.input_sequences:
# return 0, 0
# else:
# state.cache[name]["skip_init"] = True
# state.cache[name]["indicators"] = state.cbar_indicators
# state.cache[name]["bars"] = state.bars if mdl.use_bars else {}
value = mdl.predict(state.cache[name]["bars"], state.cache[name]["indicators"])
# value = mdl.predict(state.cache[name]["bars"], state.cache[name]["indicators"])
value = mdl.predict(state)
return 0, value
except Exception as e:
printanyway(str(e)+format_exc())

View File

@ -1,3 +1,5 @@
import os
os.environ["KERAS_BACKEND"] = "jax"
from v2realbot.strategy.base import StrategyState
from v2realbot.strategy.StrategyOrderLimitVykladaciNormalizedMYSELL import StrategyOrderLimitVykladaciNormalizedMYSELL
from v2realbot.enums.enums import RecordType, StartBarAlign, Mode, Account, Followup
@ -15,7 +17,6 @@ import numpy as np
#from icecream import install, ic
from rich import print as printanyway
from threading import Event
import os
from traceback import format_exc
def initialize_dynamic_indicators(state):