import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt


class SoftTanhFunction(torch.autograd.Function):
    '''
    SoftTanh(x) = sign(x) * tanh(ln(1 + ln(1 + |x|)))
    f(x) ∈ (-1, 1)
    dy/dx ∈ (0, 1]
    '''
    @staticmethod
    def forward(ctx, x):
        abs_x = torch.abs(x)
        u = torch.log1p(abs_x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        y = torch.sign(x) * tanh_v
        ctx.save_for_backward(abs_x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        # dy/dx = sech^2(v) / ((1 + u) * (1 + |x|)), with u = ln(1 + |x|) and
        # v = ln(1 + u); the sign(x) factors from the chain rule cancel.
        abs_x, = ctx.saved_tensors
        u = torch.log1p(abs_x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        sech_v_square = 1 - tanh_v.square()
        denominator = (1 + u) * (1 + abs_x)
        d_x = sech_v_square / denominator
        grad_x = grad_output * d_x
        return grad_x


def SoftTanh(x):
    return SoftTanhFunction.apply(x)
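

# A minimal sanity-check sketch (not part of the original module): it compares the
# hand-written backward above with PyTorch's numerical gradients via
# torch.autograd.gradcheck. The helper name, sizes, and tolerances are illustrative.
def _gradcheck_softtanh(n=16, seed=0):
    torch.manual_seed(seed)
    # gradcheck expects double-precision inputs with requires_grad=True
    x = torch.randn(n, dtype=torch.float64, requires_grad=True)
    return torch.autograd.gradcheck(SoftTanhFunction.apply, (x,), eps=1e-6, atol=1e-4)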


class SoftSigmoidFunction(torch.autograd.Function):
    '''
    SoftSigmoid(x) = (sign(x) * tanh(ln(1 + ln(1 + |2x|))) + 1) / 2
    f(x) ∈ (0, 1)
    dy/dx ∈ (0, 1]
    '''
    @staticmethod
    def forward(ctx, x):
        abs_2x = torch.abs(2 * x)
        u = torch.log1p(abs_2x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        y = (torch.sign(x) * tanh_v + 1) * 0.5
        ctx.save_for_backward(abs_2x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        # dy/dx = sech^2(v) / ((1 + u) * (1 + |2x|)); the factor 2 from d|2x|/dx
        # cancels the 0.5 applied in the forward pass.
        abs_2x, = ctx.saved_tensors
        u = torch.log1p(abs_2x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        sech_v_square = 1 - tanh_v.square()
        denominator = (1 + u) * (1 + abs_2x)
        d_x = sech_v_square / denominator
        grad_x = grad_output * d_x
        return grad_x


def SoftSigmoid(x):
    return SoftSigmoidFunction.apply(x)
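

# A small consistency sketch (illustrative, not in the original file): by construction
# SoftSigmoid(x) equals (SoftTanh(2 * x) + 1) / 2, so the two functions can be
# cross-checked against each other. The helper name and sizes are arbitrary choices.
def _check_softsigmoid_identity(n=16, seed=0):
    torch.manual_seed(seed)
    x = torch.randn(n, dtype=torch.float64)
    return torch.allclose(SoftSigmoid(x), (SoftTanh(2 * x) + 1) * 0.5)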


class AdaptiveSoftTanhFunction(torch.autograd.Function):
    '''
    AdaptiveSoftTanh(x) = alpha * sign(x) * tanh(ln(1 + ln(1 + |x|))) + beta
    f(x) ∈ (beta - |alpha|, beta + |alpha|)
    dy/dx ∈ (0, alpha] if alpha > 0
    dy/dx ∈ [alpha, 0) if alpha < 0
    '''
    @staticmethod
    def forward(ctx, x, alpha, beta):
        abs_x = torch.abs(x)
        u = torch.log1p(abs_x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        y = torch.sign(x) * tanh_v * alpha + beta
        ctx.save_for_backward(x, alpha)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        abs_x = torch.abs(x)
        u = torch.log1p(abs_x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        sech_v_square = 1 - tanh_v.square()
        denominator = (1 + u) * (1 + abs_x)
        d_x = sech_v_square / denominator
        # y = alpha * sign(x) * tanh(v) + beta, so dy/dx carries a factor of alpha.
        grad_x = grad_output * d_x * alpha
        grad_alpha = grad_output * torch.sign(x) * tanh_v
        grad_beta = grad_output.clone()
        # alpha and beta have shape (channels,) and broadcast over the last
        # dimension, so their gradients are reduced over all leading dimensions.
        sum_dims = [d for d in range(grad_output.dim()) if d != grad_output.dim() - 1]
        if sum_dims:
            grad_alpha = grad_alpha.sum(dim=sum_dims)
            grad_beta = grad_beta.sum(dim=sum_dims)
        return grad_x, grad_alpha, grad_beta


class AdaptiveSoftTanh(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.empty(channels))
        self.beta = nn.Parameter(torch.zeros(channels))
        nn.init.normal_(self.alpha, mean=0, std=math.sqrt(2 / channels))

    def forward(self, x):
        return AdaptiveSoftTanhFunction.apply(x, self.alpha, self.beta)
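

# A minimal usage sketch (illustrative, not in the original file): AdaptiveSoftTanh
# assumes the channel dimension is the last one, since alpha/beta have shape
# (channels,) and their gradients are summed over all leading dimensions above.
# The layer sizes below are arbitrary.
class _TinyMLP(nn.Module):
    def __init__(self, in_features=32, hidden=64, out_features=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.act = AdaptiveSoftTanh(hidden)  # per-channel learnable alpha/beta
        self.fc2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        return self.fc2(self.act(self.fc1(x)))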


if __name__ == "__main__":
    print("=== Testing the x = 0 case ===")
    # Create an input tensor containing 0
    x = torch.tensor([0.0], dtype=torch.float32, requires_grad=True)
    # Forward pass
    y = SoftTanh(x)
    print(f"Forward output (x=0): {y.item()}")
    # Backward pass
    y.backward()
    print(f"Gradient (x=0): {x.grad.item()}")
    # Verify the gradient: the derivative at x = 0 should be exactly 1
    grad_ok = torch.allclose(x.grad, torch.tensor([1.0]))
    print(f"Gradient check: {'passed' if grad_ok else 'failed'}")

    x = torch.linspace(-10, 10, 10000, requires_grad=True)
    y_soft = SoftTanh(x)
    y_tanh = torch.tanh(x)
    grad_soft = torch.autograd.grad(y_soft, x, torch.ones_like(y_soft))[0]
    grad_tanh = torch.autograd.grad(y_tanh, x, torch.ones_like(y_tanh))[0]
    x_np = x.detach().numpy()
    y_soft_np = y_soft.detach().numpy()
    y_tanh_np = y_tanh.detach().numpy()
    grad_soft_np = grad_soft.detach().numpy()
    grad_tanh_np = grad_tanh.detach().numpy()

    plt.figure(figsize=(12, 10))
    # Function plot
    plt.subplot(2, 1, 1)
    plt.plot(x_np, y_soft_np, 'b-', linewidth=2, label='SoftTanh')
    plt.plot(x_np, y_tanh_np, 'r--', linewidth=2, label='Tanh')
    plt.title('Function Comparison')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.grid(True)
    plt.legend()
    # Gradient plot
    plt.subplot(2, 1, 2)
    plt.plot(x_np, grad_soft_np, 'g-', linewidth=2, label='SoftTanh Gradient')
    plt.plot(x_np, grad_tanh_np, 'm--', linewidth=2, label='Tanh Gradient')
    plt.title('Gradient Comparison')
    plt.xlabel('x')
    plt.ylabel('dy/dx')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig('./softtanh_comparison.png', dpi=300)
    plt.close()

    # Generate x values in log space, positive half-axis only
    x = torch.logspace(-2, 3, 100000, base=10, requires_grad=True)
    y_soft = SoftTanh(x)
    y_tanh = torch.tanh(x)
    grad_soft = torch.autograd.grad(y_soft, x, torch.ones_like(y_soft))[0]
    grad_tanh = torch.autograd.grad(y_tanh, x, torch.ones_like(y_tanh))[0]
    # Convert to numpy arrays for plotting
    x_np = x.detach().numpy()
    y_soft_np = y_soft.detach().numpy()
    y_tanh_np = y_tanh.detach().numpy()
    grad_soft_np = grad_soft.detach().numpy()
    grad_tanh_np = grad_tanh.detach().numpy()
    # Create the figure and subplots
    plt.figure(figsize=(14, 12))
    # Function plot (log-log scale)
    plt.subplot(2, 1, 1)
    plt.loglog(x_np, 1 - y_soft_np, 'b-', linewidth=2, label='SoftTanh (1-y)')
    plt.loglog(x_np, 1 - y_tanh_np, 'r--', linewidth=2, label='Tanh (1-y)')
    plt.title('Function Asymptotic Behavior (Log-Log Scale)')
    plt.xlabel('x (log scale)')
    plt.ylabel('1 - y (log scale)')
    plt.grid(True, which="both", ls="--")
    plt.legend()
    # Gradient plot (log-log scale)
    plt.subplot(2, 1, 2)
    plt.loglog(x_np, grad_soft_np, 'g-', linewidth=2, label='SoftTanh Gradient')
    plt.loglog(x_np, grad_tanh_np, 'm--', linewidth=2, label='Tanh Gradient')
    plt.title('Gradient Decay (Log-Log Scale)')
    plt.xlabel('x (log scale)')
    plt.ylabel('dy/dx (log scale)')
    plt.grid(True, which="both", ls="--")
    plt.legend()
    plt.tight_layout()
    plt.savefig('./softtanh_log_comparison.png', dpi=300)
    plt.close()