import torch
import torch.nn as nn
import math
import matplotlib.pyplot as plt


class SoftTanhFunction(torch.autograd.Function):
    '''
    SoftTanh(x) = sign(x) * tanh(ln(1 + ln(1 + |x|)))
    f(x) ∈ (-1, 1)
    dy/dx ∈ (0, 1]
    '''

    @staticmethod
    def forward(ctx, x):
        abs_x = torch.abs(x)
        u = torch.log1p(abs_x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        y = torch.sign(x) * tanh_v
        ctx.save_for_backward(abs_x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        abs_x, = ctx.saved_tensors
        # Recompute intermediates; dy/dx = sech^2(v) / ((1 + u) * (1 + |x|))
        u = torch.log1p(abs_x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        sech_v_square = 1 - tanh_v.square()
        denominator = (1 + u) * (1 + abs_x)
        d_x = sech_v_square / denominator
        grad_x = grad_output * d_x
        return grad_x


def SoftTanh(x):
    return SoftTanhFunction.apply(x)


class SoftSigmoidFunction(torch.autograd.Function):
    '''
    SoftSigmoid(x) = (sign(x) * tanh(ln(1 + ln(1 + |2x|))) + 1) / 2
    f(x) ∈ (0, 1)
    dy/dx ∈ (0, 1]
    '''

    @staticmethod
    def forward(ctx, x):
        abs_2x = torch.abs(2 * x)
        u = torch.log1p(abs_2x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        y = (torch.sign(x) * tanh_v + 1) * 0.5
        ctx.save_for_backward(abs_2x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        abs_2x, = ctx.saved_tensors
        # dy/dx = sech^2(v) / ((1 + u) * (1 + |2x|)); the factor 2 from d|2x|/dx
        # cancels against the 0.5 scaling in the forward pass.
        u = torch.log1p(abs_2x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        sech_v_square = 1 - tanh_v.square()
        denominator = (1 + u) * (1 + abs_2x)
        d_x = sech_v_square / denominator
        grad_x = grad_output * d_x
        return grad_x


def SoftSigmoid(x):
    return SoftSigmoidFunction.apply(x)


class AdaptiveSoftTanhFunction(torch.autograd.Function):
    '''
    AdaptiveSoftTanh(x) = alpha * sign(x) * tanh(ln(1 + ln(1 + |x|))) + beta
    f(x) ∈ (beta - |alpha|, beta + |alpha|)
    dy/dx ∈ (0, alpha]  if alpha > 0
    dy/dx ∈ [alpha, 0)  if alpha < 0
    '''

    @staticmethod
    def forward(ctx, x, alpha, beta):
        abs_x = torch.abs(x)
        u = torch.log1p(abs_x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        y = torch.sign(x) * tanh_v * alpha + beta
        ctx.save_for_backward(x, alpha)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        abs_x = torch.abs(x)
        u = torch.log1p(abs_x)
        v = torch.log1p(u)
        tanh_v = torch.tanh(v)
        sech_v_square = 1 - tanh_v.square()
        denominator = (1 + u) * (1 + abs_x)
        d_x = sech_v_square / denominator
        # Chain rule: the forward output is scaled by alpha, so the input
        # gradient must carry the same factor.
        grad_x = grad_output * d_x * alpha
        grad_alpha = grad_output * torch.sign(x) * tanh_v
        grad_beta = grad_output.clone()
        # alpha and beta are per-channel (last-dimension) parameters, so reduce
        # their gradients over all remaining dimensions.
        sum_dims = [d for d in range(grad_output.dim()) if d != grad_output.dim() - 1]
        if sum_dims:
            grad_alpha = grad_alpha.sum(dim=sum_dims)
            grad_beta = grad_beta.sum(dim=sum_dims)
        return grad_x, grad_alpha, grad_beta


class AdaptiveSoftTanh(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.empty(channels))
        self.beta = nn.Parameter(torch.zeros(channels))
        nn.init.normal_(self.alpha, mean=0, std=math.sqrt(2 / channels))

    def forward(self, x):
        return AdaptiveSoftTanhFunction.apply(x, self.alpha, self.beta)


if __name__ == "__main__":
    print("=== Testing the x = 0 case ===")
    # Create an input tensor containing 0
    x = torch.tensor([0.0], dtype=torch.float32, requires_grad=True)
    # Forward pass
    y = SoftTanh(x)
    print(f"Forward output (x=0): {y.item()}")
    # Backward pass
    y.backward()
    print(f"Gradient (x=0): {x.grad.item()}")
    # Verify the gradient (the analytic derivative at 0 is 1)
    grad_ok = torch.allclose(x.grad, torch.tensor([1.0]))
    print(f"Gradient check: {'passed' if grad_ok else 'failed'}")
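
    # A minimal gradcheck sketch (an illustrative addition, not part of the
    # original tests): torch.autograd.gradcheck compares the hand-written
    # backward passes against numerical finite differences. Double-precision
    # inputs are used because gradcheck's default tolerances assume float64.
    x_check = torch.randn(8, dtype=torch.float64, requires_grad=True)
    print("SoftTanh gradcheck:",
          torch.autograd.gradcheck(SoftTanhFunction.apply, (x_check,)))
    print("SoftSigmoid gradcheck:",
          torch.autograd.gradcheck(SoftSigmoidFunction.apply, (x_check,)))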

    # Compare SoftTanh against Tanh on a linear grid
    x = torch.linspace(-10, 10, 10000, requires_grad=True)
    y_soft = SoftTanh(x)
    y_tanh = torch.tanh(x)
    grad_soft = torch.autograd.grad(y_soft, x, torch.ones_like(y_soft))[0]
    grad_tanh = torch.autograd.grad(y_tanh, x, torch.ones_like(y_tanh))[0]

    # Convert to numpy arrays for plotting
    x_np = x.detach().numpy()
    y_soft_np = y_soft.detach().numpy()
    y_tanh_np = y_tanh.detach().numpy()
    grad_soft_np = grad_soft.detach().numpy()
    grad_tanh_np = grad_tanh.detach().numpy()

    plt.figure(figsize=(12, 10))

    # function graph
    plt.subplot(2, 1, 1)
    plt.plot(x_np, y_soft_np, 'b-', linewidth=2, label='SoftTanh')
    plt.plot(x_np, y_tanh_np, 'r--', linewidth=2, label='Tanh')
    plt.title('Function Comparison')
    plt.xlabel('x')
    plt.ylabel('y')
    plt.grid(True)
    plt.legend()

    # grad graph
    plt.subplot(2, 1, 2)
    plt.plot(x_np, grad_soft_np, 'g-', linewidth=2, label='SoftTanh Gradient')
    plt.plot(x_np, grad_tanh_np, 'm--', linewidth=2, label='Tanh Gradient')
    plt.title('Gradient Comparison')
    plt.xlabel('x')
    plt.ylabel('dy/dx')
    plt.grid(True)
    plt.legend()

    plt.tight_layout()
    plt.savefig('./softtanh_comparison.png', dpi=300)
    plt.close()

    # Generate x values in log space; only the positive half-axis is considered
    x = torch.logspace(-2, 3, 100000, base=10, requires_grad=True)
    y_soft = SoftTanh(x)
    y_tanh = torch.tanh(x)
    grad_soft = torch.autograd.grad(y_soft, x, torch.ones_like(y_soft))[0]
    grad_tanh = torch.autograd.grad(y_tanh, x, torch.ones_like(y_tanh))[0]

    # Convert to numpy arrays for plotting
    x_np = x.detach().numpy()
    y_soft_np = y_soft.detach().numpy()
    y_tanh_np = y_tanh.detach().numpy()
    grad_soft_np = grad_soft.detach().numpy()
    grad_tanh_np = grad_tanh.detach().numpy()

    # Create the figure and subplots
    plt.figure(figsize=(14, 12))

    # Function plot (log-log scale)
    plt.subplot(2, 1, 1)
    plt.loglog(x_np, 1 - y_soft_np, 'b-', linewidth=2, label='SoftTanh (1-y)')
    plt.loglog(x_np, 1 - y_tanh_np, 'r--', linewidth=2, label='Tanh (1-y)')
    plt.title('Function Asymptotic Behavior (Log-Log Scale)')
    plt.xlabel('x (log scale)')
    plt.ylabel('1 - y (log scale)')
    plt.grid(True, which="both", ls="--")
    plt.legend()

    # Gradient plot (log-log scale)
    plt.subplot(2, 1, 2)
    plt.loglog(x_np, grad_soft_np, 'g-', linewidth=2, label='SoftTanh Gradient')
    plt.loglog(x_np, grad_tanh_np, 'm--', linewidth=2, label='Tanh Gradient')
    plt.title('Gradient Decay (Log-Log Scale)')
    plt.xlabel('x (log scale)')
    plt.ylabel('dy/dx (log scale)')
    plt.grid(True, which="both", ls="--")
    plt.legend()

    plt.tight_layout()
    plt.savefig('./softtanh_log_comparison.png', dpi=300)
    plt.close()
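
    # A minimal usage sketch for the module wrapper (an illustrative addition,
    # not part of the original script). AdaptiveSoftTanh is assumed to act on
    # the last (channel) dimension, so a (batch, channels) input is used here.
    act = AdaptiveSoftTanh(channels=8)
    inp = torch.randn(4, 8, requires_grad=True)
    out = act(inp)
    out.sum().backward()
    print(f"alpha grad shape: {tuple(act.alpha.grad.shape)}, "
          f"beta grad shape: {tuple(act.beta.grad.shape)}")

    # Optionally gradcheck the adaptive variant as well, covering the
    # per-channel alpha and beta gradients (double precision for gradcheck).
    x64 = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
    a64 = torch.randn(8, dtype=torch.float64, requires_grad=True)
    b64 = torch.randn(8, dtype=torch.float64, requires_grad=True)
    print("AdaptiveSoftTanh gradcheck:",
          torch.autograd.gradcheck(AdaptiveSoftTanhFunction.apply, (x64, a64, b64)))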