Location: D:\2023\TemporalFusionTransformer
Installation (virtual environment: torch310)
pip install pytorch-lightning
pip install pytorch_forecasting
# torch < 2.0 is required for the GPU build
# March 2023: pytorch_forecasting did not yet support torch 2.0, so torch is installed separately to get the GPU version
# A numpy version mismatch caused errors; fixed by reinstalling numpy (see below)
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
pip install "numpy<1.24"  # resolve the numpy version conflict
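A quick sanity check of the environment (a minimal sketch; the expected values in the comments simply reflect the pins above):

import numpy as np
import torch

print(torch.__version__)          # expected: 1.13.1+cu117
print(np.__version__)             # expected: < 1.24
print(torch.cuda.is_available())  # should print True if the CUDA build is active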
import numpy as np
import pandas as pd
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
sample_data = pd.DataFrame(
    dict(
        time_idx=np.tile(np.arange(6), 3),
        target=np.array([0, 1, 2, 3, 4, 5, 20, 21, 22, 23, 24, 25, 40, 41, 42, 43, 44, 45]),
        group=np.repeat(np.arange(3), 6),
        holidays=np.tile(['X', 'Black Friday', 'X', 'Christmas', 'X', 'X'], 3),
    )
)
sample_data
# 3 groups (0, 1, 2); the holidays pattern is identical for every group
# time_idx is given in the first column
    time_idx  target  group      holidays
0          0       0      0             X
1          1       1      0  Black Friday
2          2       2      0             X
3          3       3      0     Christmas
4          4       4      0             X
5          5       5      0             X
6          0      20      1             X
7          1      21      1  Black Friday
8          2      22      1             X
9          3      23      1     Christmas
10         4      24      1             X
11         5      25      1             X
12         0      40      2             X
13         1      41      2  Black Friday
14         2      42      2             X
15         3      43      2     Christmas
16         4      44      2             X
17         5      45      2             X
#----------------------------------------------------
# create the time-series dataset from the pandas df
dataset = TimeSeriesDataSet(
    sample_data,
    group_ids=["group"],
    target="target",
    time_idx="time_idx",
    max_encoder_length=2,
    max_prediction_length=3,
    time_varying_unknown_reals=["target"],
    static_categoricals=["holidays"],
    target_normalizer=None,
)
# pass the dataset to a dataloader
dataloader = dataset.to_dataloader(batch_size=1)
# load the first batch
x, y = next(iter(dataloader))
print(x['encoder_target'])
print(x['groups'])
print(x['decoder_target'])
# 2 values go into the encoder; the next 3 values are the prediction window
tensor([[21., 22.]])
tensor([[1]])
tensor([[23., 24., 25.]])
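As a quick sanity check on this toy dataset, the Baseline model imported above can be compared against the actual decoder targets. A minimal sketch, assuming the pytorch_forecasting API of this era where predict() returns a plain tensor of point forecasts:

import torch
from pytorch_forecasting import Baseline

# use an unshuffled dataloader so that actuals and predictions line up
eval_dataloader = dataset.to_dataloader(train=False, batch_size=4)
# Baseline simply repeats the last encoder value across the prediction horizon
actuals = torch.cat([y[0] for x, y in iter(eval_dataloader)])   # decoder targets
baseline_predictions = Baseline().predict(eval_dataloader)      # naive forecasts
print((actuals - baseline_predictions).abs().mean())            # MAE of the naive baseline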
#----------------------------------------------------
For the per-household electricity consumption forecasting data:
max_prediction_length = 24   # predict the next day (24 hours)
max_encoder_length = 7 * 24  # based on the previous 7 days
# hold out the last day (24 hours) from training
training_cutoff = time_df["hours_from_start"].max() - max_prediction_length
from pytorch_forecasting.data import GroupNormalizer

training = TimeSeriesDataSet(
    time_df[lambda x: x.hours_from_start <= training_cutoff],
    time_idx="hours_from_start",
    target="power_usage",
    group_ids=["consumer_id"],            # multiple groups
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["consumer_id"],  # the consumer id never changes over time
    time_varying_known_reals=["hours_from_start", "day", "day_of_week", "month", "hour"],
    time_varying_unknown_reals=["power_usage"],
    # apply a softplus transform to the target before normalizing (log/log1p/logit/relu are also available)
    # normalize per group: there are many groups and their value ranges differ
    target_normalizer=GroupNormalizer(
        groups=["consumer_id"], transformation="softplus"
    ),  # we normalize by group
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)
# predict=True: keep only the last max_prediction_length points of each series as the prediction window
validation = TimeSeriesDataSet.from_dataset(training, time_df,
                                            predict=True, stop_randomization=True)
# create dataloaders for our model
batch_size = 64
# to_dataloader returns an object that works like a standard torch DataLoader
# if you have a strong GPU, feel free to increase the number of workers
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)
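The next step (not written out in these notes) is to build the TFT from the training dataset and fit it with pytorch-lightning. A minimal sketch, assuming the pytorch_lightning 1.x API that matches the pinned torch 1.13 environment; the hyperparameter values are illustrative only:

import pytorch_lightning as pl
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# build the model from the dataset so feature encodings match the TimeSeriesDataSet
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=16,
    attention_head_size=4,
    dropout=0.1,
    hidden_continuous_size=8,
    loss=QuantileLoss(),  # TFT forecasts quantiles rather than a single point
    log_interval=10,
    reduce_on_plateau_patience=4,
)

trainer = pl.Trainer(
    max_epochs=10,
    accelerator="gpu",
    devices=1,
    gradient_clip_val=0.1,
)
trainer.fit(tft, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)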
[References]
1. Medium blog, TFT tutorial: https://towardsdatascience.com/temporal-fusion-transformer-time-series-forecasting-with-deep-learning-complete-tutorial-d32c1e51cd91
2. [April 2023]
- XGBoost, LightGBM: https://www.youtube.com/watch?v=4Jz4_IOgS4c
- WRN code: https://github.com/creinders/ChimeraMix/tree/main/models
- Data analysis / cleaning: https://double-d.tistory.com/m/14
- Data cleansing and normalization (scikit-learn basics): https://cyan91.tistory.com/m/40
- Data cleaning in 5 easy steps + examples: https://www.iteratorshq.com/blog/data-cleaning-in-5-easy-steps/
- Why we should use pytorch lightning: https://baeseongsu.github.io/posts/pytorch-lightning-introduction/
- Good explanation of Transformers: https://www.youtube.com/watch?v=AA621UofTUA
- OpenRefine tool: https://www.youtube.com/watch?v=nORS7STbLyk / https://www.youtube.com/watch?v=oRH-1RG8oQY
- TFT application examples: https://github.com/IKKIM00/stock-and-pm2.5-prediction-using-TFT / https://dacon.io/competitions/official/235736/data