본문 바로가기

pytorch 사용법

pytorch parameter 합치기, loss 합치기

파이토치 이용 시 가끔씩 헷갈리는 문법을 정리해보겠다.

 

 

 

#1 파라미터 한번에 업데이트 하기

 

optimizer = torch.optim.Adam(encoder_decoder_class.parameters(), lr = 1e-4)

 

만약 우리가 위처럼 인코더 모델(class)과 디코더 모델(class)을 하나로 이은 class가 아닌 각각의 class로 생성해서 가지고 있고, 이 두 클래스의 인스턴스의 변수를 한 번에 업데이트 하기위해서는 아래처럼 하면 된다.

 

parameters = list(encoder_class.parameters()) + list(decoder_class.parameters())

optimizer = torch.optim.Adam(parameters, lr = 1e-4)   

 

이렇게 두 인스턴스의 변수를 리스트로 만들어 합치면 된다.

parameters()함수는 제너레이터를 리턴하는데 이를 list로 변환해서 붙여준 것이다.

 

https://pytorch.org/docs/master/optim.html를 보면 알 수 있듯이 torch.optim.Adam은 제너레이터, 리스트, 딕셔너리를 모두 받을 수 있다.

밑의 캡처본은 하나의 예시로 iterable of variable 뿐만이 아니라 iterable of dict도 가능함을 보여준다. 

 

파이토치 공식 페이지 : https://pytorch.org/docs/master/optim.html

 

제너레이터는 간단하게 설명하면 이터레이터를 생성해 주는 함수다.

 

 

이터레이터란?

next()를 호출할 때 다음값을 생성해내는 상태를 가진 헬퍼 객체이다. next()를 가진 모든 객체는 이터레이터이다. 값을 생성해내는 방법과는 무관하다. 즉, 이터레이터는 값 생성기이다. 다음값을 요청할 때마다 내부 상태를 유지하고 있기 때문에 다음값을 계산하는 방법을 알고있다. 이터레이터(iterator)를 제공한다면 시퀀스가 아닌 타입도 for 루프로 탐색할 수 있다. 이터레이터는 내장 함수 iter() 를 사용해서 얻는다. 이렇게 얻은 이터레이터를 탐색하려면 내장 함수 next() 를 사용하면 된다.

 

next(iter())

 

 

제너레이터란?

이터레이터를 생성해 주는 함수다.

일반적인 함수는 사용이 종료되면 결과값을 호출부로 반환 후 함수 자체를 종료시킨 후 내부에서 사용된 데이터들을 메모리 상에서 클리어 한다. 하지만, 제너레이터 함수를 실행하는 중 yield를 만날 경우, 해당 함수는 그 상태로 정지 되며, 반환 값을 next() 를 호출한 쪽으로 전달 하게 된다. 이후 해당 함수는 종료되지 않고 그 상태를 유지하게 된다. 즉, 함수에서 사용된 local 변수나 instruction pointer 등과 같은 함수 내부에서 사용된 데이터들이 메모리에 그대로 유지되는 것이다.

 

제너레이터는 값에 데이터에 접근할 때 마다 메모리에 적재하기 때문에, 큰 데이터를 다루는 경우, 리스트보다 안정적이고, 효율적이다.

 

 

 

#2 loss 합치기

 

 

예를 들어 GAN에서 discriminator는 가짜 사진에 대한 loss 값(레이블 : 0)과 진짜 사진에 대한 loss값(레이블 : 1)이 모두 존재하며 이를 합한 값을 loss로 이용한다.

이런 경우, 각각 의 loss의 차원만 잘 맞다면 + 하면 된다.

 

 

간단한 예시)

optimizer = torch.optim.Adam(discriminator.parameters(), lr = learning_rate)

gen_fake = generator(x)

fake_pic = discriminator(gen_fake)

loss_func = nn.MSELoss()

discriminator_loss = torch.sum(loss_func(fake_pic, zero_label)) + torch.sum(loss_func(real_pic, one_label))

discriminator_loss.backward()

optimizer.step()