-
Notifications
You must be signed in to change notification settings - Fork 43
Description
# Forecasting configuration: N_LAGS past observations are used to forecast the
# next HORIZON observations of the TARGET columns (presumably days, given the
# "daily" dataset name — confirm against the data).
N_LAGS = 14
HORIZON = 7
TARGET = ["Incoming Solar", "Air Temp", "Vapor Pressure"]

# Multivariate series, indexed by its "datetime" column parsed as datetimes.
# NOTE(review): the path is relative to the current working directory.
mvtseries = pd.read_csv(
    "../assets/daily_multivariate_timeseries.csv",
    index_col="datetime",
    parse_dates=["datetime"],
)

# Number of columns in the frame; used below as the model's input dimension.
n_vars = len(mvtseries.columns)
class MultivariateSeriesDataModule(pl.LightningDataModule):
class MultiOutputLSTM(pl.LightningModule):
# Model: every column of the frame is an input feature (input_dim=n_vars); the
# output is configured by horizon=HORIZON and n_output=len(TARGET), presumably
# one HORIZON-step forecast per target variable — confirm in MultiOutputLSTM.
model = MultiOutputLSTM(
    input_dim=n_vars, hidden_dim=32, num_layers=2, horizon=HORIZON, n_output=len(TARGET)
)

# Data module wrapping the series; test_size=0.3 is presumably the fraction of
# the data held out for testing — confirm in MultivariateSeriesDataModule.
datamodule = MultivariateSeriesDataModule(
    data=mvtseries,
    n_lags=N_LAGS,
    horizon=HORIZON,
    test_size=0.3,
    batch_size=32,
    target_variables=TARGET,
)

# Stop training once "val_loss" fails to decrease by at least 1e-4 for 10
# consecutive validation checks (capped at 30 epochs by the trainer below).
early_stop_callback = EarlyStopping(
    monitor="val_loss", min_delta=1e-4, patience=10, verbose=False, mode="min"
)

trainer = pl.Trainer(max_epochs=30, callbacks=[early_stop_callback])

# Pass the datamodule by keyword: the second positional parameter of
# Trainer.fit is `train_dataloaders`, not `datamodule`.
trainer.fit(model=model, datamodule=datamodule)

# Evaluate through the datamodule (as predict does below) so Lightning runs the
# datamodule's test-stage setup before requesting its test dataloader, instead
# of calling test_dataloader() directly. No ckpt_path is given, so the
# in-memory model is used as left by fit: with early stopping these are the
# last-epoch weights, not necessarily the best validation ones.
trainer.test(model=model, datamodule=datamodule)

# Per-batch prediction outputs from the datamodule's predict dataloader.
forecasts = trainer.predict(model=model, datamodule=datamodule)