jax support added/multiinput

This commit is contained in:
David Brazda
2023-12-26 18:25:25 +01:00
parent 17b9859a73
commit 77faa919c0
4 changed files with 85 additions and 48 deletions

View File

@ -59,7 +59,7 @@ Jinja2==3.1.2
joblib==1.3.2 joblib==1.3.2
jsonschema==4.17.3 jsonschema==4.17.3
jupyterlab-widgets==3.0.9 jupyterlab-widgets==3.0.9
keras==2.15.0 keras
kiwisolver==1.4.4 kiwisolver==1.4.4
libclang==16.0.6 libclang==16.0.6
llvmlite==0.39.1 llvmlite==0.39.1
@ -70,7 +70,8 @@ matplotlib==3.8.2
matplotlib-inline==0.1.6 matplotlib-inline==0.1.6
mdurl==0.1.2 mdurl==0.1.2
ml-dtypes==0.2.0 ml-dtypes==0.2.0
mlroom @ git+https://github.com/drew2323/mlroom.git@768c88348a0bd24c244a8720c67abb20fcb1403e mlroom @ git+https://github.com/drew2323/mlroom.git
keras-tcn @ git+https://github.com/drew2323/keras-tcn.git
mplfinance==0.12.10b0 mplfinance==0.12.10b0
msgpack==1.0.4 msgpack==1.0.4
mypy-extensions==1.0.0 mypy-extensions==1.0.0
@ -131,11 +132,6 @@ streamlit==1.20.0
structlog==23.1.0 structlog==23.1.0
TA-Lib==0.4.28 TA-Lib==0.4.28
tenacity==8.2.2 tenacity==8.2.2
tensorboard==2.15.1
tensorboard-data-server==0.7.1
tensorflow==2.15.0
tensorflow-estimator==2.15.0
tensorflow-io-gcs-filesystem==0.34.0
termcolor==2.3.0 termcolor==2.3.0
threadpoolctl==3.2.0 threadpoolctl==3.2.0
tinydb==4.7.1 tinydb==4.7.1

View File

@ -71,29 +71,66 @@ $(document).ready(function() {
}); });
} }
// function downloadModel(modelName) {
// $.ajax({
// url: '/model/download-model/' + modelName,
// type: 'GET',
// processData: false,
// contentType: false,
// responseType: 'blob', // This is important
// beforeSend: function (xhr) {
// xhr.setRequestHeader('X-API-Key', API_KEY);
// },
// success: function(data, status, xhr) {
// // Get a URL for the blob to download
// var blob = new Blob([data], { type: 'application/octet-stream' });
// //var blob = new Blob([data], { type: xhr.getResponseHeader('Content-Type') });
// var downloadUrl = URL.createObjectURL(blob);
// var a = document.createElement('a');
// a.href = downloadUrl;
// a.download = modelName;
// document.body.appendChild(a);
// a.click();
// // Clean up
// window.URL.revokeObjectURL(downloadUrl);
// a.remove();
// },
// error: function(xhr, status, error) {
// alert('Error downloading model: ' + error + xhr.responseText + status);
// }
// });
// }
function downloadModel(modelName) { function downloadModel(modelName) {
$.ajax({ fetch('/model/download-model/' + modelName, {
url: '/model/download-model/' + modelName, method: 'GET', // GET is the default method, but it's good to be explicit
type: 'GET', headers: {
beforeSend: function (xhr) { 'X-API-Key': API_KEY
xhr.setRequestHeader('X-API-Key', API_KEY);
},
success: function(data, status, xhr) {
// Get a URL for the blob to download
var blob = new Blob([data], { type: xhr.getResponseHeader('Content-Type') });
var downloadUrl = URL.createObjectURL(blob);
var a = document.createElement('a');
a.href = downloadUrl;
a.download = modelName;
document.body.appendChild(a);
a.click();
// Clean up
window.URL.revokeObjectURL(downloadUrl);
a.remove();
},
error: function(xhr, status, error) {
alert('Error downloading model: ' + error + xhr.responseText + status);
} }
})
.then(response => {
if (response.ok) return response.blob();
throw new Error('Network response was not ok.');
})
.then(blob => {
// Check the size of the blob here; it should match the Content-Length from the server
console.log('Size of downloaded blob:', blob.size);
// Create a link element, use it for download, and remove it
let url = window.URL.createObjectURL(blob);
let a = document.createElement('a');
a.style.display = 'none';
a.href = url;
a.download = modelName;
document.body.appendChild(a);
a.click();
window.setTimeout(() => {
document.body.removeChild(a);
window.URL.revokeObjectURL(url);
}, 100); // Cleanup after a small delay
})
.catch(error => {
console.error('Download error:', error);
}); });
} }

View File

@ -8,7 +8,7 @@ import numpy as np
from collections import defaultdict from collections import defaultdict
""" """
Předpokladm, že buď používáme 1) bar+standard indikator 2) cbars indicators - zatim nepodporovano spolecne (jine time rozliseni)
""" """
def model(state, params, ind_name): def model(state, params, ind_name):
funcName = "model" funcName = "model"
@ -32,25 +32,28 @@ def model(state, params, ind_name):
try: try:
mdl = state.vars.loaded_models[name] mdl = state.vars.loaded_models[name]
#Optimalizovano, aby se v kazde iteraci nemusel volat len # #Optimalizovano, aby se v kazde iteraci nemusel volat len
if state.cache.get(name, {}).get("skip_init", False) is False: # if state.cache.get(name, {}).get("skip_init", False) is False:
if mdl.use_cbars is False: # if mdl.use_cbars is False:
if len(state.bars["close"]) < mdl.input_sequences: # if len(state.bars["close"]) < mdl.input_sequences:
return 0, 0 # return 0, 0
else: # else:
state.cache[name]["skip_init"] = True # state.cache[name]["skip_init"] = True
state.cache[name]["indicators"] = state.indicators # state.cache[name]["indicators"] = state.indicators
state.cache[name]["bars"] = state.bars if mdl.use_bars else {} # state.cache[name]["bars"] = state.bars if mdl.use_bars else {}
#return -2, f"too soon - not enough data for seq {seq=}" # #return -2, f"too soon - not enough data for seq {seq=}"
else: # else:
if len(state.cbar_indicators["time"]) < mdl.input_sequences: # if len(state.cbar_indicators["time"]) < mdl.input_sequences:
return 0, 0 # return 0, 0
else: # else:
state.cache[name]["skip_init"] = True # state.cache[name]["skip_init"] = True
state.cache[name]["indicators"] = state.cbar_indicators # state.cache[name]["indicators"] = state.cbar_indicators
state.cache[name]["bars"] = state.bars if mdl.use_bars else {} # state.cache[name]["bars"] = state.bars if mdl.use_bars else {}
# value = mdl.predict(state.cache[name]["bars"], state.cache[name]["indicators"])
value = mdl.predict(state)
value = mdl.predict(state.cache[name]["bars"], state.cache[name]["indicators"])
return 0, value return 0, value
except Exception as e: except Exception as e:
printanyway(str(e)+format_exc()) printanyway(str(e)+format_exc())

View File

@ -1,3 +1,5 @@
import os
os.environ["KERAS_BACKEND"] = "jax"
from v2realbot.strategy.base import StrategyState from v2realbot.strategy.base import StrategyState
from v2realbot.strategy.StrategyOrderLimitVykladaciNormalizedMYSELL import StrategyOrderLimitVykladaciNormalizedMYSELL from v2realbot.strategy.StrategyOrderLimitVykladaciNormalizedMYSELL import StrategyOrderLimitVykladaciNormalizedMYSELL
from v2realbot.enums.enums import RecordType, StartBarAlign, Mode, Account, Followup from v2realbot.enums.enums import RecordType, StartBarAlign, Mode, Account, Followup
@ -15,7 +17,6 @@ import numpy as np
#from icecream import install, ic #from icecream import install, ic
from rich import print as printanyway from rich import print as printanyway
from threading import Event from threading import Event
import os
from traceback import format_exc from traceback import format_exc
def initialize_dynamic_indicators(state): def initialize_dynamic_indicators(state):