Sunday, March 26, 2023

How to use PyTorch Forecasting

Location: D:\2023\TemporalFusionTransformer

Installation (virtual environment torch310)

pip install pytorch-lightning
pip install pytorch_forecasting
# As of March 2023, pytorch_forecasting does not support torch 2.0,
# so a torch version below 2.0 is installed separately to get the GPU (CUDA 11.7) build.
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117
# The numpy version did not match; reinstalling a version below 1.24 resolved the error.
# The quotes keep the shell from treating "<" as a redirection.
pip install "numpy<1.24"
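
After installing, it can help to confirm that the versions that ended up in the environment respect the pins above (torch below 2.0, numpy below 1.24). The following is a minimal sketch that uses only the standard library. The helper names (version_tuple, check_constraints, installed_versions) are made up for this note, and the upper bounds simply restate the pins from the commands above.

```python
import re
from importlib import metadata

# Upper bounds restating the pins used above (an assumption of this note).
UPPER_BOUNDS = {"torch": (2, 0), "numpy": (1, 24)}


def version_tuple(version):
    """Turn '1.13.1+cu117' into (1, 13, 1), ignoring the local build suffix."""
    numbers = re.findall(r"\d+", version.split("+")[0])
    return tuple(int(n) for n in numbers[:3])


def check_constraints(installed):
    """Map each package to True/False (below the bound) or None (not installed)."""
    result = {}
    for name, bound in UPPER_BOUNDS.items():
        version = installed.get(name)
        result[name] = None if version is None else version_tuple(version) < bound
    return result


def installed_versions(names):
    """Look up installed versions without importing the packages themselves."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print(check_constraints(installed_versions(UPPER_BOUNDS)))
```

A value of False for either package means the pin was not respected and the corresponding pip command should be run again.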


import numpy as np
import pandas as pd
from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet

sample_data = pd.DataFrame(
    dict(
        time_idx=np.tile(np.arange(6), 3),
        target=np.array([0,1,2,3,4,5,20,21,22,23,24,25,40,41,42,43,44,45]),
        group=np.repeat(np.arange(3), 6),
        holidays = np.tile(['X','Black Friday', 'X','Christmas','X', 'X'],3),
    )
)
sample_data

# There are 3 groups (0, 1, 2); the holidays pattern is the same for every group.
# The leftmost unnamed column is the DataFrame index; time_idx is the first named column.
    time_idx  target  group  holidays
0   0         0       0      X
1   1         1       0      Black Friday
2   2         2       0      X
3   3         3       0      Christmas
4   4         4       0      X
5   5         5       0      X
6   0         20      1      X
7   1         21      1      Black Friday
8   2         22      1      X
9   3         23      1      Christmas
10  4         24      1      X
11  5         25      1      X
12  0         40      2      X
13  1         41      2      Black Friday
14  2         42      2      X
15  3         43      2      Christmas
16  4         44      2      X
17  5         45      2      X
#----------------------------------------------------

# create the time-series dataset from the pandas df
dataset = TimeSeriesDataSet(
    sample_data,
    group_ids=["group"],
    target="target",
    time_idx="time_idx",
    max_encoder_length=2,
    max_prediction_length=3,
    time_varying_unknown_reals=["target"],
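    # holidays changes over time within each group, so it is declared as a
    # time-varying known categorical rather than a static one
    time_varying_known_categoricals=["holidays"],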
    target_normalizer=None
)

# pass the dataset to a dataloader
dataloader = dataset.to_dataloader(batch_size=1)

# load the first batch (to_dataloader shuffles by default, so the group shown can differ between runs)
x, y = next(iter(dataloader))
print(x['encoder_target'])
print(x['groups'])
print(x['decoder_target'])

# Two values are used as the encoder input; the following three values are the prediction targets.
tensor([[21., 22.]])
tensor([[1]])
tensor([[23., 24., 25.]])
#----------------------------------------------------
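
The batch above is one of several windows that can be cut from each group. As a rough illustration of how max_encoder_length=2 and max_prediction_length=3 split a series, here is a simplified pure-Python sketch. It is not the library's implementation: the function name sliding_windows is made up for this note, and it assumes fixed window lengths, which is what the settings above should amount to when the minimum lengths are left at their defaults.

```python
def sliding_windows(series, encoder_length, prediction_length):
    """Yield (encoder, decoder) pairs of fixed lengths from one group's series."""
    window = encoder_length + prediction_length
    for start in range(len(series) - window + 1):
        encoder = series[start:start + encoder_length]
        decoder = series[start + encoder_length:start + window]
        yield encoder, decoder


# Target values of the three groups from the sample data above.
groups = {
    0: [0, 1, 2, 3, 4, 5],
    1: [20, 21, 22, 23, 24, 25],
    2: [40, 41, 42, 43, 44, 45],
}

samples = [
    (group, encoder, decoder)
    for group, series in groups.items()
    for encoder, decoder in sliding_windows(series, encoder_length=2, prediction_length=3)
]

for group, encoder, decoder in samples:
    print(group, encoder, decoder)
```

Each series of 6 points yields 2 windows of length 5, so the sketch produces 6 samples in total. The window (1, [21, 22], [23, 24, 25]) corresponds to the batch printed above.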


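For the household electricity consumption forecasting data

from pytorch_forecasting.data import GroupNormalizer

# time_df is assumed to be loaded already, with the columns used below
max_prediction_length = 24  # look at the previous 7 days to predict the next day (24 hours)
max_encoder_length = 7*24
# train on everything except the last day (24 hours)
training_cutoff = time_df["hours_from_start"].max() - max_prediction_length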

training = TimeSeriesDataSet(
    time_df[lambda x: x.hours_from_start <= training_cutoff],
    time_idx="hours_from_start",
    target="power_usage",
    group_ids=["consumer_id"],  # there are multiple groups
    min_encoder_length=max_encoder_length // 2,
    max_encoder_length=max_encoder_length,
    min_prediction_length=1,
    max_prediction_length=max_prediction_length,
    static_categoricals=["consumer_id"],  # the consumer id never changes
    time_varying_known_reals=["hours_from_start","day","day_of_week", "month", 'hour'],
    time_varying_unknown_reals=['power_usage'],
    # The target is transformed with softplus before normalization
    # (other options include log, log1p, logit, relu).
    # Normalization is done per group, because there are several groups with different value ranges.
    target_normalizer=GroupNormalizer(
        groups=["consumer_id"], transformation="softplus"
    ),  # we normalize by group
    add_relative_time_idx=True,
    add_target_scales=True,
    add_encoder_length=True,
)
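
To see why the target is normalized per group, here is a minimal pure-Python sketch of group-wise standardization. It is only an illustration of the idea, not GroupNormalizer's implementation: it leaves out the softplus transformation, and the function names and the two example consumers are hypothetical.

```python
from statistics import mean, pstdev


def fit_group_scales(rows):
    """Compute (center, scale) per group from (group_id, value) pairs."""
    by_group = {}
    for group_id, value in rows:
        by_group.setdefault(group_id, []).append(value)
    # Fall back to a scale of 1.0 for constant series to avoid dividing by zero.
    return {g: (mean(v), pstdev(v) or 1.0) for g, v in by_group.items()}


def normalize(rows, scales):
    """Standardize each value with the statistics of its own group."""
    return [(g, (v - scales[g][0]) / scales[g][1]) for g, v in rows]


def denormalize(rows, scales):
    """Invert normalize(): map standardized values back to the original units."""
    return [(g, v * scales[g][1] + scales[g][0]) for g, v in rows]


# Two hypothetical consumers whose usage differs by two orders of magnitude.
rows = [("small", 1.0), ("small", 2.0), ("small", 3.0),
        ("large", 100.0), ("large", 200.0), ("large", 300.0)]

scales = fit_group_scales(rows)
normalized = normalize(rows, scales)
restored = denormalize(normalized, scales)
print(normalized)
```

After normalization both consumers land on the same scale, so the model does not have to learn each consumer's absolute level. The stored per-group statistics are what allow predictions to be mapped back to the original units.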

validation = TimeSeriesDataSet.from_dataset(
    training, time_df, predict=True, stop_randomization=True
)
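
The arithmetic behind training_cutoff and the validation window can be checked with plain Python. This is a minimal sketch on a hypothetical series of 30 days of hourly data for a single consumer. It assumes that predict=True makes the validation set predict the last max_prediction_length steps of each series; the variable names for the windows are made up for this note.

```python
max_prediction_length = 24
max_encoder_length = 7 * 24

# Hypothetical series: 30 days of hourly data, indexed 0 .. 719.
hours_from_start = list(range(30 * 24))

training_cutoff = max(hours_from_start) - max_prediction_length

# Rows kept for training: everything up to and including the cutoff.
training_hours = [h for h in hours_from_start if h <= training_cutoff]

# Held-out window: the final 24 hours, predicted from the 168 hours before them.
decoder_hours = [h for h in hours_from_start if h > training_cutoff]
encoder_hours = hours_from_start[
    -(max_encoder_length + max_prediction_length):-max_prediction_length
]

print(training_cutoff, len(training_hours), len(decoder_hours))
print(encoder_hours[0], encoder_hours[-1], decoder_hours[0], decoder_hours[-1])
```

The encoder window of the held-out sample ends exactly at the cutoff, so the validation targets are hours the model never saw as targets during training.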

# create dataloaders for our model
batch_size = 64
# to_dataloader gives a dataloader that works like a regular torch DataLoader
# num_workers sets the number of data-loading worker processes; increase it if data loading is the bottleneck
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 10, num_workers=0)



[References]

1. Medium blog on TFT: https://towardsdatascience.com/temporal-fusion-transformer-time-series-forecasting-with-deep-learning-complete-tutorial-d32c1e51cd91

2. [April 2023]

- XGBoost, LightGBM: https://www.youtube.com/watch?v=4Jz4_IOgS4c

- WRN code: https://github.com/creinders/ChimeraMix/tree/main/models

- Data analysis and cleaning: https://double-d.tistory.com/m/14

- Data cleansing and normalization, scikit-learn basics: https://cyan91.tistory.com/m/40

- Data cleaning in 5 easy steps + Examples: https://www.iteratorshq.com/blog/data-cleaning-in-5-easy-steps/

- Why we should use PyTorch Lightning: https://baeseongsu.github.io/posts/pytorch-lightning-introduction/

- A good explanation of Transformers: https://www.youtube.com/watch?v=AA621UofTUA

- OpenRefine tool: https://www.youtube.com/watch?v=nORS7STbLyk / https://www.youtube.com/watch?v=oRH-1RG8oQY

- TFT application examples: https://github.com/IKKIM00/stock-and-pm2.5-prediction-using-TFT / https://dacon.io/competitions/official/235736/data

