import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt


class SoftTanhFunction(torch.autograd.Function):
    '''
    SoftTanh(x) = sign(x) * tanh(ln(1 + ln(1 + |x|)))
    f(x) ∈ (-1, 1)
    dy/dx ∈ (0, 1]
    '''
    @staticmethod
    def forward(ctx, x):
        abs_x = torch.abs(x)
        u = torch.log1p(abs_x)        # u = ln(1 + |x|)
        v = torch.log1p(u)            # v = ln(1 + u)
        tanh_v = torch.tanh(v)
        y = torch.sign(x) * tanh_v
        ctx.save_for_backward(abs_x)  # |x| is enough to rebuild the gradient
        return y

    @staticmethod
    def backward(ctx, grad_output):
        abs_x, = ctx.saved_tensors
        u = torch.log1p(abs_x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        # dy/dx = sech^2(v) / ((1 + u) * (1 + |x|)); the two sign(x) factors cancel,
        # so the expression also holds at x = 0, where the gradient is exactly 1.
        sech_v_square = 1 - tanh_v.square()
        denominator = (1 + u) * (1 + abs_x)
        d_x = sech_v_square / denominator
        grad_x = grad_output * d_x
        return grad_x


def SoftTanh(x):
    return SoftTanhFunction.apply(x)
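

# Illustrative sketch, not part of the original file: torch.autograd.gradcheck compares
# the hand-written backward above against finite differences in double precision.
# The helper name and the test points are our own choices.
def _check_softtanh_gradient():
    x = torch.randn(32, dtype=torch.double, requires_grad=True)
    # gradcheck raises if the analytical and numerical Jacobians disagree.
    return torch.autograd.gradcheck(SoftTanhFunction.apply, (x,))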


class SoftSigmoidFunction(torch.autograd.Function):
    '''
    SoftSigmoid(x) = (sign(x) * tanh(ln(1 + ln(1 + |2x|))) + 1) / 2
    f(x) ∈ (0, 1)
    dy/dx ∈ (0, 1]
    '''
    @staticmethod
    def forward(ctx, x):
        abs_2x = torch.abs(2 * x)
        u = torch.log1p(abs_2x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        y = (torch.sign(x) * tanh_v + 1) * 0.5
        ctx.save_for_backward(abs_2x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        abs_2x, = ctx.saved_tensors
        u = torch.log1p(abs_2x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        sech_v_square = 1 - tanh_v.square()
        # The factor 2 from d|2x|/dx cancels the 0.5 output scaling, so no extra
        # constant appears: dy/dx = sech^2(v) / ((1 + u) * (1 + |2x|)).
        denominator = (1 + u) * (1 + abs_2x)
        d_x = sech_v_square / denominator
        grad_x = grad_output * d_x
        return grad_x


def SoftSigmoid(x):
    return SoftSigmoidFunction.apply(x)
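

# Illustrative check, not part of the original file: by construction SoftSigmoid relates
# to SoftTanh exactly as sigmoid relates to tanh, i.e. SoftSigmoid(x) = (SoftTanh(2x) + 1) / 2,
# and the hand-written gradients should agree as well. The helper name is our own.
def _check_softsigmoid_identity():
    x = torch.randn(1000, dtype=torch.double, requires_grad=True)
    lhs = SoftSigmoid(x)
    rhs = (SoftTanh(2 * x) + 1) * 0.5
    forward_ok = torch.allclose(lhs, rhs)
    grad_lhs = torch.autograd.grad(lhs.sum(), x)[0]
    grad_rhs = torch.autograd.grad(rhs.sum(), x)[0]
    return forward_ok and torch.allclose(grad_lhs, grad_rhs)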


class AdaptiveSoftTanhFunction(torch.autograd.Function):
    '''
    AdaptiveSoftTanh(x) = alpha * sign(x) * tanh(ln(1 + ln(1 + |x|))) + beta
    f(x) ∈ (beta - |alpha|, beta + |alpha|)
    dy/dx ∈ (0, alpha] if alpha > 0
    dy/dx ∈ [alpha, 0) if alpha < 0
    '''
    @staticmethod
    def forward(ctx, x, alpha, beta):
        abs_x = torch.abs(x)
        u = torch.log1p(abs_x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        y = torch.sign(x) * tanh_v * alpha + beta
        ctx.save_for_backward(x, alpha)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        abs_x = torch.abs(x)
        u = torch.log1p(abs_x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        sech_v_square = 1 - tanh_v.square()
        denominator = (1 + u) * (1 + abs_x)

        d_x = sech_v_square / denominator

        # dy/dx = alpha * sech^2(v) / ((1 + u) * (1 + |x|)); the alpha factor must be
        # included here to match the forward pass.
        grad_x = grad_output * d_x * alpha
        grad_alpha = grad_output * torch.sign(x) * tanh_v
        grad_beta = grad_output.clone()

        # alpha and beta are per-channel (last dimension), so their gradients are
        # reduced over every other dimension.
        sum_dims = [d for d in range(grad_output.dim()) if d != grad_output.dim() - 1]
        if sum_dims:
            grad_alpha = grad_alpha.sum(dim=sum_dims)
            grad_beta = grad_beta.sum(dim=sum_dims)

        return grad_x, grad_alpha, grad_beta


class AdaptiveSoftTanh(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.empty(channels))
        self.beta = nn.Parameter(torch.zeros(channels))
        nn.init.normal_(self.alpha, mean=0, std=math.sqrt(2 / channels))

    def forward(self, x):
        return AdaptiveSoftTanhFunction.apply(x, self.alpha, self.beta)
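

# Illustrative usage sketch, not part of the original file: apply AdaptiveSoftTanh to a
# (batch, channels) tensor and compare its hand-written gradient with a reference built
# from standard autograd ops, which confirms the alpha scaling in backward(). The helper
# name and the shapes are our own assumptions.
def _check_adaptive_softtanh(batch=8, channels=4):
    torch.manual_seed(0)
    act = AdaptiveSoftTanh(channels).double()
    x = torch.randn(batch, channels, dtype=torch.double, requires_grad=True)

    y = act(x)
    y.sum().backward()

    # Reference: the same formula written with differentiable ops. The two gradients
    # agree for x != 0 (autograd's sign/abs give a 0 gradient exactly at 0).
    x_ref = x.detach().clone().requires_grad_(True)
    v_ref = torch.log1p(torch.log1p(x_ref.abs()))
    y_ref = act.alpha * torch.sign(x_ref) * torch.tanh(v_ref) + act.beta
    y_ref.sum().backward()

    return torch.allclose(y, y_ref) and torch.allclose(x.grad, x_ref.grad)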


if __name__ == "__main__":
    print("=== Testing the x = 0 case ===")

    # Input tensor containing 0
    x = torch.tensor([0.0], dtype=torch.float32, requires_grad=True)

    # Forward pass
    y = SoftTanh(x)
    print(f"Forward output (x=0): {y.item()}")

    # Backward pass
    y.backward()
    print(f"Gradient (x=0): {x.grad.item()}")

    # Verify the gradient value
    grad_ok = torch.allclose(x.grad, torch.tensor([1.0]))
    print(f"Gradient check: {'passed' if grad_ok else 'failed'}")

    x = torch.linspace(-10, 10, 10000, requires_grad=True)

    y_soft = SoftTanh(x)
    y_tanh = torch.tanh(x)

    grad_soft = torch.autograd.grad(y_soft, x, torch.ones_like(y_soft))[0]
    grad_tanh = torch.autograd.grad(y_tanh, x, torch.ones_like(y_tanh))[0]

    x_np = x.detach().numpy()
    y_soft_np = y_soft.detach().numpy()
    y_tanh_np = y_tanh.detach().numpy()
    grad_soft_np = grad_soft.detach().numpy()
    grad_tanh_np = grad_tanh.detach().numpy()

    plt.figure(figsize=(12, 10))
    # function graph
    plt.subplot(2, 1, 1)
    plt.plot(x_np, y_soft_np, 'b-', linewidth=2, label='SoftTanh')
    plt.plot(x_np, y_tanh_np, 'r--', linewidth=2, label='Tanh')
    plt.title('Function Comparison')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.grid(True)
    plt.legend()
    # grad graph
    plt.subplot(2, 1, 2)
    plt.plot(x_np, grad_soft_np, 'g-', linewidth=2, label='SoftTanh Gradient')
    plt.plot(x_np, grad_tanh_np, 'm--', linewidth=2, label='Tanh Gradient')
    plt.title('Gradient Comparison')
    plt.xlabel('x')
    plt.ylabel('dy/dx')
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.savefig('./softtanh_comparison.png', dpi=300)
    plt.close()

    # Generate x values in log space, positive half-axis only
    x = torch.logspace(-2, 3, 100000, base=10, requires_grad=True)

    y_soft = SoftTanh(x)
    y_tanh = torch.tanh(x)

    grad_soft = torch.autograd.grad(y_soft, x, torch.ones_like(y_soft))[0]
    grad_tanh = torch.autograd.grad(y_tanh, x, torch.ones_like(y_tanh))[0]

    # Convert to numpy arrays for plotting
    x_np = x.detach().numpy()
    y_soft_np = y_soft.detach().numpy()
    y_tanh_np = y_tanh.detach().numpy()
    grad_soft_np = grad_soft.detach().numpy()
    grad_tanh_np = grad_tanh.detach().numpy()

    # Create figure and subplots
    plt.figure(figsize=(14, 12))

    # Function plot (log-log scale)
    plt.subplot(2, 1, 1)
    plt.loglog(x_np, 1 - y_soft_np, 'b-', linewidth=2, label='SoftTanh (1-y)')
    plt.loglog(x_np, 1 - y_tanh_np, 'r--', linewidth=2, label='Tanh (1-y)')
    plt.title('Function Asymptotic Behavior (Log-Log Scale)')
    plt.xlabel('x (log scale)')
    plt.ylabel('1 - y (log scale)')
    plt.grid(True, which="both", ls="--")
    plt.legend()

    # Gradient plot (log-log scale)
    plt.subplot(2, 1, 2)
    plt.loglog(x_np, grad_soft_np, 'g-', linewidth=2, label='SoftTanh Gradient')
    plt.loglog(x_np, grad_tanh_np, 'm--', linewidth=2, label='Tanh Gradient')
    plt.title('Gradient Decay (Log-Log Scale)')
    plt.xlabel('x (log scale)')
    plt.ylabel('dy/dx (log scale)')
    plt.grid(True, which="both", ls="--")
    plt.legend()

    plt.tight_layout()
    plt.savefig('./softtanh_log_comparison.png', dpi=300)
    plt.close()