jax support added/multiinput

This commit is contained in:
David Brazda
2023-12-26 18:25:25 +01:00
parent 17b9859a73
commit 77faa919c0
4 changed files with 85 additions and 48 deletions

View File

@ -59,7 +59,7 @@ Jinja2==3.1.2
joblib==1.3.2
jsonschema==4.17.3
jupyterlab-widgets==3.0.9
keras==2.15.0
keras
kiwisolver==1.4.4
libclang==16.0.6
llvmlite==0.39.1
@ -70,7 +70,8 @@ matplotlib==3.8.2
matplotlib-inline==0.1.6
mdurl==0.1.2
ml-dtypes==0.2.0
mlroom @ git+https://github.com/drew2323/mlroom.git@768c88348a0bd24c244a8720c67abb20fcb1403e
mlroom @ git+https://github.com/drew2323/mlroom.git
keras-tcn @ git+https://github.com/drew2323/keras-tcn.git
mplfinance==0.12.10b0
msgpack==1.0.4
mypy-extensions==1.0.0
@ -131,11 +132,6 @@ streamlit==1.20.0
structlog==23.1.0
TA-Lib==0.4.28
tenacity==8.2.2
tensorboard==2.15.1
tensorboard-data-server==0.7.1
tensorflow==2.15.0
tensorflow-estimator==2.15.0
tensorflow-io-gcs-filesystem==0.34.0
termcolor==2.3.0
threadpoolctl==3.2.0
tinydb==4.7.1

View File

@ -71,29 +71,66 @@ $(document).ready(function() {
});
}
// function downloadModel(modelName) {
// $.ajax({
// url: '/model/download-model/' + modelName,
// type: 'GET',
// processData: false,
// contentType: false,
// responseType: 'blob', // This is important
// beforeSend: function (xhr) {
// xhr.setRequestHeader('X-API-Key', API_KEY);
// },
// success: function(data, status, xhr) {
// // Get a URL for the blob to download
// var blob = new Blob([data], { type: 'application/octet-stream' });
// //var blob = new Blob([data], { type: xhr.getResponseHeader('Content-Type') });
// var downloadUrl = URL.createObjectURL(blob);
// var a = document.createElement('a');
// a.href = downloadUrl;
// a.download = modelName;
// document.body.appendChild(a);
// a.click();
// // Clean up
// window.URL.revokeObjectURL(downloadUrl);
// a.remove();
// },
// error: function(xhr, status, error) {
// alert('Error downloading model: ' + error + xhr.responseText + status);
// }
// });
// }
function downloadModel(modelName) {
$.ajax({
url: '/model/download-model/' + modelName,
type: 'GET',
beforeSend: function (xhr) {
xhr.setRequestHeader('X-API-Key', API_KEY);
},
success: function(data, status, xhr) {
// Get a URL for the blob to download
var blob = new Blob([data], { type: xhr.getResponseHeader('Content-Type') });
var downloadUrl = URL.createObjectURL(blob);
var a = document.createElement('a');
a.href = downloadUrl;
a.download = modelName;
document.body.appendChild(a);
a.click();
// Clean up
window.URL.revokeObjectURL(downloadUrl);
a.remove();
},
error: function(xhr, status, error) {
alert('Error downloading model: ' + error + xhr.responseText + status);
fetch('/model/download-model/' + modelName, {
method: 'GET', // GET is the default method, but it's good to be explicit
headers: {
'X-API-Key': API_KEY
}
})
.then(response => {
if (response.ok) return response.blob();
throw new Error('Network response was not ok.');
})
.then(blob => {
// Check the size of the blob here; it should match the Content-Length from the server
console.log('Size of downloaded blob:', blob.size);
// Create a link element, use it for download, and remove it
let url = window.URL.createObjectURL(blob);
let a = document.createElement('a');
a.style.display = 'none';
a.href = url;
a.download = modelName;
document.body.appendChild(a);
a.click();
window.setTimeout(() => {
document.body.removeChild(a);
window.URL.revokeObjectURL(url);
}, 100); // Cleanup after a small delay
})
.catch(error => {
console.error('Download error:', error);
});
}

View File

@ -8,7 +8,7 @@ import numpy as np
from collections import defaultdict
"""
Předpokladm, že buď používáme 1) bar+standard indikator 2) cbars indicators - zatim nepodporovano spolecne (jine time rozliseni)
"""
def model(state, params, ind_name):
funcName = "model"
@ -32,25 +32,28 @@ def model(state, params, ind_name):
try:
mdl = state.vars.loaded_models[name]
#Optimalizovano, aby se v kazde iteraci nemusel volat len
if state.cache.get(name, {}).get("skip_init", False) is False:
if mdl.use_cbars is False:
if len(state.bars["close"]) < mdl.input_sequences:
return 0, 0
else:
state.cache[name]["skip_init"] = True
state.cache[name]["indicators"] = state.indicators
state.cache[name]["bars"] = state.bars if mdl.use_bars else {}
#return -2, f"too soon - not enough data for seq {seq=}"
else:
if len(state.cbar_indicators["time"]) < mdl.input_sequences:
return 0, 0
else:
state.cache[name]["skip_init"] = True
state.cache[name]["indicators"] = state.cbar_indicators
state.cache[name]["bars"] = state.bars if mdl.use_bars else {}
# #Optimalizovano, aby se v kazde iteraci nemusel volat len
# if state.cache.get(name, {}).get("skip_init", False) is False:
# if mdl.use_cbars is False:
# if len(state.bars["close"]) < mdl.input_sequences:
# return 0, 0
# else:
# state.cache[name]["skip_init"] = True
# state.cache[name]["indicators"] = state.indicators
# state.cache[name]["bars"] = state.bars if mdl.use_bars else {}
# #return -2, f"too soon - not enough data for seq {seq=}"
# else:
# if len(state.cbar_indicators["time"]) < mdl.input_sequences:
# return 0, 0
# else:
# state.cache[name]["skip_init"] = True
# state.cache[name]["indicators"] = state.cbar_indicators
# state.cache[name]["bars"] = state.bars if mdl.use_bars else {}
value = mdl.predict(state.cache[name]["bars"], state.cache[name]["indicators"])
# value = mdl.predict(state.cache[name]["bars"], state.cache[name]["indicators"])
value = mdl.predict(state)
return 0, value
except Exception as e:
printanyway(str(e)+format_exc())

View File

@ -1,3 +1,5 @@
import os
os.environ["KERAS_BACKEND"] = "jax"
from v2realbot.strategy.base import StrategyState
from v2realbot.strategy.StrategyOrderLimitVykladaciNormalizedMYSELL import StrategyOrderLimitVykladaciNormalizedMYSELL
from v2realbot.enums.enums import RecordType, StartBarAlign, Mode, Account, Followup
@ -15,7 +17,6 @@ import numpy as np
#from icecream import install, ic
from rich import print as printanyway
from threading import Event
import os
from traceback import format_exc
def initialize_dynamic_indicators(state):