Deep Learning

Tensor를 detach() 하는 이유

Kimhj 2023. 10. 19. 10:57

 

  • 모델을 학습시키면서 AUROC 지표를 뽑을 때, inference 결과를 저장해서 확인해야할 때가 있다.
  • 그러면 예측 결과와 확률값, 타겟값(ground truth)을 numpy array 에 담거나 list 에 넣어서 저장해야 하는데, 이때 가끔 detach() 에러가 발생한다.
  • detach 하는 이유는 주로 그래디언트 계산과 자동 미분(autograd) 시스템과 관련이 있음.
    • 그래디언트 추적 중단: PyTorch와 같은 딥러닝 프레임워크에서는 기본적으로 텐서의 연산을 추적하여 그래디언트(미분)를 자동으로 계산함. 딥러닝 모델을 학습할 때 텐서에 대한 그래디언트를 추적하는 것은 중요하지만, 때로는 그래디언트를 추적하지 않아야 하는 경우가 있는데, 이때 .detach()를 사용하여 그래디언트 추적을 중단함.
    • 메모리 관리: 그래디언트 추적을 중단하면 그래디언트를 저장하는 메모리를 절약할 수 있다. 특히 중간 연산 결과를 저장하거나 특정 텐서의 그래디언트를 필요로하지 않는 경우, .detach()를 사용하여 메모리 사용량을 줄일 수 있음.
    • In-place 연산:  in-place 연산(in-place operation)으로 인해 원본 텐서에 영향을 미치는 것을 방지하기 위해 .detach()를 사용할 수 있음.
  • 예시 코드
import torch

# 그래디언트 추적을 중단하기 위해 detach()를 사용
x = torch.tensor([2.0], requires_grad=True)
y = x**2
z = y.detach()  # z는 그래디언트가 추적되지 않는 텐서

# y와 z를 사용한 연산
result = y + z

# result.backward()  # result에 대한 그래디언트를 계산하면 z는 그래디언트가 0으로 처리됨