LSTM-Network mit mehreren Eingängen möglich?

CyborgBeta

Banned
Registriert
Jan. 2021
Beiträge
3.958
Hi,

ich habe mir [1] durchgelesen und nach-implementiert.

Mein Problem ist, dass ich nicht jeweils nur einen Eingang habe (x_0, x_1, usw.), sondern 4:

x_0_0, x_0_1, x_0_2, x_0_3, x_1_0, x_1_1, x_1_2, x_1_3, usw.

Wie kann ich diese 4 als Eingang (und später als Ausgang) verwenden?
 
Aber (eine Rückfrage sei gestattet) ... ich werde daraus noch nicht so richtig schlau:

1725208648154.png


Ich habe einen (Edit:) 4-dimensionalen mehrdimensionalen Input, ja - aber brauche doch aus jedem 4-Tupel (Quadrupel) genau einen numerischen Wert x_i. Wie erhalte ich diesen Wert?
 
Zuletzt bearbeitet:
Vielleicht ginge das mit 2D Convolutional Neural Networks: https://colah.github.io/posts/2014-07-Conv-Nets-Modular/

denn die Eingabe ist ja auch zweidimensional. Aber Conv Nets sind natürlich keine LSTM Nets ...

Klar, man könnte vor jedem x_i ein Conv Net und nach jedem h_i auch ein Conv Net schalten, aber das ergäbe überhaupt keinen Sinn mehr, da Informationen (durch die "Engpässe") entfallen würden.
 
Hallo, ich habe mich der Einfachheit wegen entschlossen, diese API zu verwenden: https://github.com/deepjavalibrary/djl/blob/master/api/src/main/java/ai/djl/nn/recurrent/LSTM.java https://javadoc.io/doc/ai.djl/api/latest/ai/djl/nn/recurrent/LSTM.html

Wüsste zufällig jemand, wie ich das Model mit den 4-Tupeln programmatisch trainieren kann?

Im Folgenden ist ein Skeleton, das aber natürlich (noch) nicht funktioniert ...

Java:
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.recurrent.LSTM;
import ai.djl.training.ParameterStore;

import java.util.List;

public class Predict {
    public static List<Main.Candle> predict(List<Main.Candle> candles, int nPredicts) {
        LSTM lstm = new LSTM.Builder().setNumLayers(4).build();
        lstm.prepare(new Shape[]{new Shape(candles.size() - 1, 4)});
        NDList forward = lstm.forward(new ParameterStore(), new NDList(candles.size() - 2), true);
        //???
        for (int i = 0; i < candles.size() - 1; i++) {
            Main.Candle c = candles.get(i);
            System.out.println(c.openTime() + " " + c.closeTime() + " " + c.open() + " " + c.low() + " " + c.high() + " " + c.close());
        }
    }
}

Ich möchte nPredicts viele Vorhersagen für die Zukunft generieren. size() - 1, weil die letzte Kerze noch nicht abgeschlossen ist, also die Gegenwart angibt.
 
Zurück
Oben