전반적인 트레이닝 코드는 다음과 같습니다.
from dataset.monte_carlo_dataset import JaxMonteCarloDataset
import mesh.model as mm
from jax_capnn.model import create_mlp, CapNN
from jax.experimental import optimizers
from jax import jit, grad, vmap
import jax
import jax.numpy as jnp
from jax_capnn.hessian import hessian
from util.io import Logger
# 1. create setting
setting = {"num_iter": 100000,
"input_dim": 2, # 2 or 3
"print_iter": 1000,
"batch_size": 1200,
"lr": 1e-3,
"run_name": "sample",
"interior_train_start": 0.3,
"environment": "Thick2DCapacitor",
"layers": [30, 30, 30, 30] }
num_iter = setting['num_iter']
print_iter = setting['print_iter']
batch_size = setting['batch_size']
input_dim = setting['input_dim']
env = mm.create_environment_by_name(setting["environment"])
# 2. create model
cap_nn = CapNN(setting)
cap_nn.init_for_train()
# 3. create dataset
dataset = JaxMonteCarloDataset(env, num_iter, batch_size, input_dim, boundary_sample_ratio = 0.1)
logger = Logger(setting['run_name'])
logger.save_setting(setting)
print("Start training")
# 4. optimize
for i in dataset.iter():
data = dataset.get_item(i)
bx = data['boundary']
nx = data['non_boundary']
y = data['bc_value']
nv = data['bc_normal']
bx = env.normalize_point(bx)
nx = env.normalize_point(nx)
y = env.normalize_potential(y)
data = {"nx": nx, "bx": bx, "y": y, "nv": nv}
cap_nn.set_input(data)
cap_nn.optimize(i)
if i % print_iter == 0:
losses = cap_nn.get_losses()
logger.save_loss(losses)
print(losses)
cap_nn.print_lr(i)
pass
# save model at last
logger.save_model(cap_nn.get_net_params())
간단히 이야기 하면 크게 4 단계로 나누어 집니다.
셋팅은 네트워크, learning rate 등등 파라메터를 정의 합니다. 중요하게 보셔야 하는 것은 'layers' 인데요 이 부분이 실제 네트워크를 정의 합니다.
아래에서 모델 생성과 최적화 과정에 대해서 설명 하겠습니다.
모델 (ANN)은 다음과 같이 만들 수 있습니다. 아래 num_channels에 셋팅의 'layers' 가 들어가게 됩니다. stax.Dense
가 matrix를 정의하고 stax.Softplus
가 activation 함수 입니다. activation 함수는 다양하게 있습니다. Softplus 가 잘 동작 해서 일단 softplus로 했습니다.
def create_mlp(num_channels = []):
modules = []
for nc in num_channels:
modules.append(stax.Dense(nc))
modules.append(stax.Softplus)
modules.append(stax.Dense(1))
return stax.serial(*modules)
실제로 second derivative를 구하는 것은 loss 계산에서 이루어 집니다.
@jit
def step(i, opt_state, data):
p = self.get_params(opt_state)
g = grad(loss)(p, data)
return self.opt_update(i, g, opt_state)
def loss(params, input):
b_gain = input['b_gain']
h_gain = input['h_gain']
j_gain = input['j_gain']
b_loss = boundary_loss(params, input) * b_gain
h_loss = hessian_loss(params, input) * h_gain
j_loss = jacobian_loss(params, input) * j_gain
return b_loss + h_loss + j_loss
jacobian loss를 제외하고 boundary loss와 hessian loss (second derivative loss)가 기존의 PDE에서 사용하는 조건들 입니다. 각각은 다음과 같이 정의 되어있습니다.
def boundary_loss(params, input):
bx, nx, y, nv = self._unpack_input(input)
targets = y
predictions = self.net_apply(params, bx)
loss = jnp.mean((targets - predictions)**2)
return loss
def hessian_loss(params, input):
bx, nx, y, nv = self._unpack_input(input)
f = lambda x: self.net_apply(params, x)
v_hessian = vmap(hessian(f))
H = v_hessian(nx)
#TODO: check if diagonal is write
h_diag = H.diagonal(0, 2, 3)
h_diag_sum = jnp.sum(h_diag , axis = -1)
loss = (h_diag_sum - self.laplacian_target_value) ** 2
return jnp.mean(loss)
boundary loss는 일반적인 loss와 같습니다. hessian loss가 network의 second derivative를 구하는 과정입니다. 여기서 hessian 함수는 jax에서 제공해 주는 함수 입니다.
jax가 특이한 것이 객체가 아니라 함수를 파라메터로 많이 사용 합니다. hessian
함수는 함수를 입력 받아서 그 함수의 hessian
함수를 출력 한다고 생각 하시면 됩니다. 그래서 결국 hessian
에 네트워크를 입력하여 hessian
함수를 얻고 그 함수에 네트워크 입력값을 넣으면 입력에 대한 출력의 미분 값을 구할 수 있는 것 입니다. vmap
은 단순히 함수를 각 열에 대해서 개별적으로 입력 되도록 만들어 주는 것 입니다.
네트워크는 두가지 입력이 있는데요, 네트워크 파라메터(params)와 입력(x) 입니다. net_apply(params, x)
는 네트워크를 계산하는 함수입니다. 그런데 저희가 계산하고자 하는 것은 입력에 대한 미분값이라 parameter가 고정된 함수가 필요 합니다. 그래서 f
라는 함수를 통해 변수가 x
뿐인 함수를 만들어서 hessian
에 넣습니다.