引言:FSRNet在超分辨率领域的里程碑意义
FSRNet(Face Super-Resolution Network)是2018年由腾讯AI Lab提出的专门针对人脸图像超分辨率的深度学习模型。与通用的图像超分辨率方法不同,FSRNet创新性地引入了先验知识引导的思想,通过结合人脸关键点热图(heatmap)和解析图(parsing map)作为辅助信息,实现了从低分辨率人脸图像到高分辨率人脸图像的高质量重建。
本文将深入剖析FSRNet的代码实现细节与核心算法原理,帮助读者全面理解这一经典模型的设计哲学与技术实现。
一、FSRNet核心算法原理
1.1 问题定义与动机
传统超分辨率方法在处理人脸图像时存在明显局限:
- 缺乏领域知识:通用方法未利用人脸结构先验
- 细节模糊:难以恢复精细的人脸特征(如眼睛、眉毛、嘴唇)
- 伪影严重:在极端低分辨率情况下容易产生失真
FSRNet通过引入多任务学习和先验引导机制,有效解决了上述问题。
1.2 网络架构总览
FSRNet采用编码器-解码器结构,包含三个关键组件:
- 特征提取编码器:从LR图像中提取多尺度特征
- 先验信息融合模块:整合关键点热图与解析图
- 渐进式重建解码器:分阶段生成HR图像
1.3 核心创新点
1.3.1 先验信息引导机制
FSRNet的核心创新在于利用人工先验而非依赖大量数据:
- 关键点热图(Keypoint Heatmap):标注人脸关键点(如眼、鼻、嘴)的高斯分布图
- 解析图(Parsing Map):语义分割结果,标注五官、头发等区域
这些先验信息在训练时提供,在推理时通过辅助网络预测得到。
1.3.2 渐进式重建策略
采用coarse-to-fine策略:
- 第一阶段:生成粗糙的HR图像
- 第二阶段:基于先验信息进行细节优化
二、FSRNet代码实现详解
2.1 环境依赖与项目结构
# requirements.txt
torch>=1.2.0
torchvision>=0.4.0
numpy>=1.16.0
opencv-python>=3.4.0
scipy>=1.2.0
Pillow>=6.0.0
# 典型项目结构
FSRNet/
├── models/
│ ├── __init__.py
│ ├── fsrnet.py # 主网络定义
│ ├── prior_net.py # 先验预测网络
│ └── loss.py # 损失函数
├── data/
│ ├── dataset.py # 数据集处理
│ └── transforms.py # 数据增强
├── utils/
│ ├── utils.py # 工具函数
│ └── face_utils.py # 人脸相关工具
├── config.py # 配置文件
└── train.py # 训练脚本
2.2 核心网络模块代码解析
2.2.1 主网络架构(fsrnet.py)
import torch
import torch.nn as nn
import torch.nn.functional as F
class FSRNet(nn.Module):
    """
    FSRNet main network: face super-resolution reconstruction.

    Input:  LR image tensor (B, 3, H, W)
    Output: HR image tensor (B, 3, scale_factor*H, scale_factor*W), values in [0, 1]
    """
    def __init__(self, scale_factor=4, num_channels=64, num_prior_channels=87):
        """
        Args:
            scale_factor: LR->HR upsampling ratio (default 4).
            num_channels: base channel width of encoder/decoder.
            num_prior_channels: channel count of the concatenated prior fed to
                PriorFusionModule. The prior is 68 keypoint heatmaps + 19
                parsing classes = 87 channels, so the default is 87 — the
                previous default (32) made PriorFusionModule crash whenever
                a prior was actually supplied.
        """
        super(FSRNet, self).__init__()
        self.scale_factor = scale_factor
        # 1. Encoder: feature extraction (two stride-2 convs: 4x downsampling)
        self.encoder = nn.Sequential(
            # Initial conv: extract base features
            nn.Conv2d(3, num_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            # Residual block 1: low-level features
            ResidualBlock(num_channels, num_channels),
            # Downsample: halve spatial size
            nn.Conv2d(num_channels, num_channels*2, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # Residual block 2: mid-level features
            ResidualBlock(num_channels*2, num_channels*2),
            # Downsample again
            nn.Conv2d(num_channels*2, num_channels*4, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            # Residual block 3: high-level features
            ResidualBlock(num_channels*4, num_channels*4),
        )
        # 2. Prior fusion module
        self.prior_fusion = PriorFusionModule(
            in_channels=num_channels*4,
            prior_channels=num_prior_channels
        )
        # 3. Decoder: progressive reconstruction (restores the input resolution)
        self.decoder = nn.Sequential(
            # Upsample 1
            nn.ConvTranspose2d(num_channels*4, num_channels*2,
                               kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            ResidualBlock(num_channels*2, num_channels*2),
            # Upsample 2
            nn.ConvTranspose2d(num_channels*2, num_channels,
                               kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            ResidualBlock(num_channels, num_channels),
            # Output layer: produce the 3-channel image
            nn.Conv2d(num_channels, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh()  # outputs in [-1, 1]; rescaled to [0, 1] in forward
        )
    def forward(self, x, prior_info=None):
        """
        Forward pass.

        Args:
            x: LR image tensor (B, 3, H, W)
            prior_info: prior dict {
                'heatmap': (B, 68, H, W),  # 68 facial landmarks
                'parsing': (B, 19, H, W)   # 19 semantic classes
            } or None to run the prior-free path.
        Returns:
            HR image tensor (B, 3, scale_factor*H, scale_factor*W)
        """
        # Encode
        features = self.encoder(x)
        # Fuse priors when available (ground truth in training,
        # PriorNet predictions at inference time)
        if prior_info is not None:
            fused_features = self.prior_fusion(features, prior_info)
        else:
            # No prior available: plain SR path
            fused_features = features
        # Decode back to the input resolution
        hr = self.decoder(fused_features)
        # Rescale Tanh output from [-1, 1] to [0, 1]
        hr = (hr + 1) / 2.0
        # BUG FIX: the encoder downsamples by 4 and the decoder only restores
        # the input resolution, so the output was (B, 3, H, W) instead of the
        # documented (B, 3, 4H, 4W). Upsample to honour scale_factor.
        if self.scale_factor != 1:
            hr = F.interpolate(hr, scale_factor=self.scale_factor,
                               mode='bilinear', align_corners=False)
        return hr
class ResidualBlock(nn.Module):
    """Residual block (conv-ReLU-conv plus skip connection).

    Generalized: when in_channels != out_channels the original element-wise
    add crashed with a shape mismatch; a 1x1 projection shortcut now matches
    channel counts. Equal-channel callers behave exactly as before (identity
    skip).
    """
    def __init__(self, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        # Identity skip when shapes already match (the only case the original
        # code supported); otherwise project with a 1x1 conv.
        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = nn.Identity()
    def forward(self, x):
        residual = self.shortcut(x)
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)
        return out + residual  # skip connection
class PriorFusionModule(nn.Module):
    """Fuses main features with facial priors (keypoint heatmaps + parsing map).

    The two prior tensors are concatenated, encoded into a 64-channel feature
    map, resized to the main feature resolution, and merged into the main
    features through a 1x1 modulation conv.
    """
    def __init__(self, in_channels, prior_channels):
        super(PriorFusionModule, self).__init__()
        # Encode raw prior maps into a fixed 64-channel representation.
        self.prior_encoder = nn.Sequential(
            nn.Conv2d(prior_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        # 1x1 conv mixes concatenated main + prior features back down
        # to in_channels.
        self.modulation = nn.Sequential(
            nn.Conv2d(in_channels + 64, in_channels, kernel_size=1),
            nn.ReLU(inplace=True),
        )
    def forward(self, main_features, prior_info):
        """Fuse main features with the prior dict.

        Args:
            main_features: (B, C, H', W')
            prior_info: dict with 'heatmap' (B, 68, H, W) and
                'parsing' (B, 19, H, W)
        Returns:
            modulated features (B, C, H', W')
        """
        stacked_priors = torch.cat(
            [prior_info['heatmap'], prior_info['parsing']], dim=1
        )  # (B, 87, H, W)
        encoded = self.prior_encoder(stacked_priors)  # (B, 64, H, W)
        # Match the spatial resolution of the main feature map.
        encoded = F.interpolate(
            encoded,
            size=main_features.shape[2:],
            mode='bilinear',
            align_corners=False,
        )
        return self.modulation(torch.cat([main_features, encoded], dim=1))
2.3 先验预测网络(prior_net.py)
在推理时,需要预测关键点热图和解析图:
class PriorNet(nn.Module):
    """Predicts facial priors (keypoint heatmaps + parsing map) from an LR image.

    A lightweight shared encoder feeds two task-specific heads; both outputs
    are bilinearly resized back to the input resolution.
    """
    def __init__(self, num_keypoints=68, num_classes=19):
        super(PriorNet, self).__init__()
        # Shared feature extractor (three stride-2 convs: 8x downsampling).
        self.shared_encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
        # Keypoint heatmap head: per-pixel probability per landmark.
        self.heatmap_head = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_keypoints, kernel_size=1),
            nn.Sigmoid(),
        )
        # Parsing head: per-pixel class distribution.
        self.parsing_head = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_classes, kernel_size=1),
            nn.Softmax(dim=1),
        )
    def forward(self, x):
        """Predict priors for an LR image.

        Args:
            x: LR image (B, 3, H, W)
        Returns:
            dict {'heatmap': (B, 68, H, W), 'parsing': (B, 19, H, W)}
        """
        shared = self.shared_encoder(x)
        target_size = x.shape[2:]

        def _upsample(t):
            # Restore the original LR spatial size.
            return F.interpolate(t, size=target_size, mode='bilinear',
                                 align_corners=False)

        return {
            'heatmap': _upsample(self.heatmap_head(shared)),
            'parsing': _upsample(self.parsing_head(shared)),
        }
2.4 损失函数设计(loss.py)
FSRNet采用多任务损失,包含多个分量:
class FSRNetLoss(nn.Module):
    """
    Composite FSRNet training loss.

    L_total = L_recon + λ1*L_prior + λ2*L_perceptual + λ3*L_adversarial
    """
    def __init__(self, lambda_prior=0.1, lambda_perceptual=0.01, lambda_adv=0.001):
        super(FSRNetLoss, self).__init__()
        self.lambda_prior = lambda_prior
        self.lambda_perceptual = lambda_perceptual
        self.lambda_adv = lambda_adv
        self.recon_loss = nn.L1Loss()             # pixel reconstruction (L1)
        self.prior_loss = nn.MSELoss()            # MSE for keypoint heatmaps
        self.perceptual_loss = PerceptualLoss()   # pretrained-VGG feature loss
        self.adversarial_loss = nn.BCEWithLogitsLoss()  # optional GAN term
    def forward(self, sr_output, hr_target, prior_info, prior_pred):
        """Compute the weighted total loss and its components.

        Args:
            sr_output:  SR image produced by the network.
            hr_target:  ground-truth HR image.
            prior_info: ground-truth prior dict (or None).
            prior_pred: predicted prior dict (or None).
        Returns:
            dict with 'total', 'recon', 'prior', 'perceptual' entries.
        """
        # 1. Reconstruction loss
        loss_recon = self.recon_loss(sr_output, hr_target)
        # 2. Prior loss (prediction accuracy), only when both sides exist
        loss_prior = 0
        if prior_info is not None and prior_pred is not None:
            # Heatmap regression accuracy
            loss_heatmap = self.prior_loss(prior_pred['heatmap'],
                                           prior_info['heatmap'])
            # Parsing accuracy: cross-entropy against the one-hot ground
            # truth, converted back to integer class labels via argmax
            loss_parsing = F.cross_entropy(
                prior_pred['parsing'],
                prior_info['parsing'].argmax(dim=1),
            )
            loss_prior = loss_heatmap + loss_parsing
        # 3. Perceptual loss
        loss_perceptual = self.perceptual_loss(sr_output, hr_target)
        # 4. Adversarial term is computed in the discriminator training
        #    step; kept at zero here so the weighting stays visible
        loss_adversarial = 0
        total_loss = (
            loss_recon
            + self.lambda_prior * loss_prior
            + self.lambda_perceptual * loss_perceptual
            + self.lambda_adv * loss_adversarial
        )
        return {
            'total': total_loss,
            'recon': loss_recon,
            'prior': loss_prior,
            'perceptual': loss_perceptual,
        }
class PerceptualLoss(nn.Module):
    """Perceptual loss: MSE in VGG19 feature space, averaged over layers."""
    def __init__(self, layer_names=['relu3_3', 'relu4_3']):
        super(PerceptualLoss, self).__init__()
        self.vgg = VGG19Features(layer_names)
        # VGG is a fixed feature extractor — freeze its weights.
        for param in self.vgg.parameters():
            param.requires_grad = False
    def forward(self, sr, hr):
        """Return the mean per-layer MSE between VGG features of sr and hr."""
        sr_features = self.vgg(sr)
        hr_features = self.vgg(hr)
        distances = [
            F.mse_loss(f_sr, f_hr)
            for f_sr, f_hr in zip(sr_features, hr_features)
        ]
        return sum(distances) / len(sr_features)
2.5 训练流程(train.py)
def train_one_epoch(model, prior_net, dataloader, optimizer, loss_fn, device):
    """Run one training epoch and return the mean total loss per batch."""
    model.train()
    prior_net.train()
    running_loss = 0
    for batch_idx, (lr_images, hr_images, priors) in enumerate(dataloader):
        lr_images, hr_images = lr_images.to(device), hr_images.to(device)
        # Predict priors from the LR input, then reconstruct.
        prior_pred = prior_net(lr_images)
        sr_output = model(lr_images, prior_pred)
        # Composite loss over reconstruction + priors.
        losses = loss_fn(sr_output, hr_images, priors, prior_pred)
        # Standard optimizer step.
        optimizer.zero_grad()
        losses['total'].backward()
        optimizer.step()
        batch_total = losses['total'].item()
        running_loss += batch_total
        # Periodic progress log.
        if batch_idx % 100 == 0:
            print(f"Batch {batch_idx}: Loss={batch_total:.4f}")
    return running_loss / len(dataloader)
def train(model, prior_net, train_loader, val_loader, config):
    """
    Full training loop: train, validate, schedule LR, checkpoint best PSNR.

    Args:
        model: FSRNet instance.
        prior_net: PriorNet instance.
        train_loader / val_loader: torch DataLoaders.
        config: object with lr, lambda_prior, lambda_perceptual, num_epochs.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    prior_net = prior_net.to(device)
    # Joint optimizer; the prior net gets a 10x smaller learning rate.
    optimizer = torch.optim.Adam(
        [
            {'params': model.parameters()},
            {'params': prior_net.parameters(), 'lr': config.lr * 0.1}
        ],
        lr=config.lr,
        betas=(0.9, 0.999)
    )
    # FIX: use mode='max' and pass PSNR directly — the original negated the
    # PSNR to fit mode='min', which worked but obscured intent and invites
    # sign mistakes on later edits.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )
    # Composite loss
    loss_fn = FSRNetLoss(
        lambda_prior=config.lambda_prior,
        lambda_perceptual=config.lambda_perceptual
    )
    best_psnr = 0
    for epoch in range(config.num_epochs):
        # Train one epoch
        train_loss = train_one_epoch(
            model, prior_net, train_loader, optimizer, loss_fn, device
        )
        # Validate
        val_psnr = validate(model, prior_net, val_loader, device)
        # Schedule LR on validation PSNR (higher is better)
        scheduler.step(val_psnr)
        # Checkpoint whenever validation PSNR improves
        if val_psnr > best_psnr:
            best_psnr = val_psnr
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'prior_net_state_dict': prior_net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'psnr': val_psnr
            }, 'best_model.pth')
        print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Val PSNR={val_psnr:.2f}")
2.6 推理代码实现
class FSRNetInference:
    """Inference wrapper around FSRNet (+ optional PriorNet).

    Loads checkpoints, handles pre/post-processing, and exposes a callable
    interface mapping an LR image path to an HR PIL image.
    """
    def __init__(self, model_path, prior_net_path=None, device='cuda'):
        """
        Args:
            model_path: path to the FSRNet checkpoint.
            prior_net_path: path to the prior-prediction network checkpoint;
                when None, FSRNet runs without prior guidance.
            device: preferred device string; falls back to CPU when CUDA
                is unavailable.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # Load FSRNet weights.
        self.model = FSRNet(scale_factor=4)
        checkpoint = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(self.device)
        self.model.eval()
        # Optionally load the prior network.
        self.prior_net = None
        if prior_net_path:
            self.prior_net = PriorNet()
            prior_checkpoint = torch.load(prior_net_path, map_location=self.device)
            self.prior_net.load_state_dict(prior_checkpoint['model_state_dict'])
            self.prior_net.to(self.device)
            self.prior_net.eval()
        # Normalize inputs to [-1, 1] to match training.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])
    def preprocess(self, image_path):
        """Load an LR image and convert it to a normalized batch tensor."""
        image = Image.open(image_path).convert('RGB')
        # Crop so both sides are multiples of 4 (required by the 4x
        # downsampling inside the encoder).
        w, h = image.size
        w = w - (w % 4)
        h = h - (h % 4)
        image = image.crop((0, 0, w, h))
        lr_tensor = self.transform(image).unsqueeze(0)  # (1, 3, H, W)
        return lr_tensor.to(self.device), image
    def postprocess(self, sr_tensor):
        """Convert the model output tensor to a PIL image."""
        # FIX: clamp to [0, 1] — float round-off can push values slightly
        # outside the range, which ToPILImage would wrap into artifacts.
        sr_tensor = sr_tensor.squeeze(0).clamp(0.0, 1.0).cpu()
        sr_image = transforms.ToPILImage()(sr_tensor)
        return sr_image
    def __call__(self, image_path, use_prior=True):
        """
        Run super-resolution on one image.

        Args:
            image_path: path to the LR image.
            use_prior: whether to use the prior network (when loaded).
        Returns:
            HR image as a PIL Image.
        """
        lr_tensor, lr_image = self.preprocess(image_path)
        with torch.no_grad():
            if use_prior and self.prior_net is not None:
                # Predict priors, then super-resolve with guidance.
                prior_pred = self.prior_net(lr_tensor)
                sr_tensor = self.model(lr_tensor, prior_pred)
            else:
                # Prior-free path: lower quality, but no PriorNet needed.
                sr_tensor = self.model(lr_tensor)
        return self.postprocess(sr_tensor)
# Usage example
if __name__ == '__main__':
    # Build the inference wrapper from saved checkpoints.
    fsrnet = FSRNetInference(
        model_path='checkpoints/fsrnet_best.pth',
        prior_net_path='checkpoints/prior_net_best.pth'
    )
    # Super-resolve a single LR image and save the result.
    hr_image = fsrnet('data/lr_face.jpg')
    hr_image.save('output/hr_face.jpg')
    print(f"超分辨率完成!输出尺寸: {hr_image.size}")
三、数据准备与处理
3.1 数据集格式要求
FSRNet需要成对的LR-HR图像以及先验信息:
class FSRNetDataset(torch.utils.data.Dataset):
    """FSRNet dataset: yields (LR tensor, HR tensor, prior dict) per sample.

    Priors are 68-channel keypoint heatmaps (.npy) and a 19-class parsing
    map (.png of class indices), both stored at HR resolution.
    """
    def __init__(self, data_dir, split='train', scale_factor=4):
        """
        Args:
            data_dir: dataset root directory.
            split: 'train' or 'val'.
            scale_factor: super-resolution factor (HR/LR size ratio).
        """
        self.scale_factor = scale_factor
        self.data_dir = data_dir
        # BUG FIX: `split` was read in __getitem__ but never stored on the
        # instance, which raised NameError at runtime; keep it here.
        self.split = split
        # Sample name list
        self.image_list = []
        with open(os.path.join(data_dir, f'{split}.txt'), 'r') as f:
            for line in f:
                self.image_list.append(line.strip())
        # Photometric/geometric augmentation (applied in training only)
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
        ])
    def __len__(self):
        return len(self.image_list)
    def __getitem__(self, idx):
        base_name = self.image_list[idx]
        # 1. Load the HR image
        hr_path = os.path.join(self.data_dir, 'HR', f'{base_name}.png')
        hr_image = Image.open(hr_path).convert('RGB')
        # 2. Synthesize the LR image by bicubic downsampling
        w, h = hr_image.size
        lr_w, lr_h = w // self.scale_factor, h // self.scale_factor
        lr_image = hr_image.resize((lr_w, lr_h), Image.BICUBIC)
        # 3. Load prior information
        # Keypoint heatmaps
        heatmap_path = os.path.join(self.data_dir, 'heatmap', f'{base_name}.npy')
        heatmap = np.load(heatmap_path)  # (68, H, W)
        # Parsing map: per-pixel class indices
        parsing_path = os.path.join(self.data_dir, 'parsing', f'{base_name}.png')
        parsing = np.array(Image.open(parsing_path))  # (H, W)
        # One-hot encode 19 classes -> (19, H, W)
        parsing_onehot = np.eye(19)[parsing]          # (H, W, 19)
        parsing_onehot = parsing_onehot.transpose(2, 0, 1)
        # 4. Augmentation (training split only). Re-seeding with the same
        # value keeps the random flip/jitter identical for HR and LR.
        if self.transform and self.split == 'train':
            seed = np.random.randint(0, 2**32)
            random.seed(seed)
            torch.manual_seed(seed)
            hr_image = self.transform(hr_image)
            random.seed(seed)
            torch.manual_seed(seed)
            lr_image = self.transform(lr_image)
        # 5. To tensors
        to_tensor = transforms.ToTensor()
        lr_tensor = to_tensor(lr_image)
        hr_tensor = to_tensor(hr_image)
        # Normalize to [-1, 1]
        lr_tensor = lr_tensor * 2 - 1
        hr_tensor = hr_tensor * 2 - 1
        # Prior tensors
        heatmap_tensor = torch.from_numpy(heatmap).float()
        parsing_tensor = torch.from_numpy(parsing_onehot).float()
        prior_info = {
            'heatmap': heatmap_tensor,
            'parsing': parsing_tensor
        }
        return lr_tensor, hr_tensor, prior_info
3.2 先验信息生成
在训练前,需要使用外部工具生成先验信息:
import cv2
import dlib
from scipy.ndimage import gaussian_filter
def generate_prior_info(hr_image_path, output_dir):
    """
    Generate the prior information (keypoint heatmaps + parsing map)
    needed for training, from one HR image.

    Args:
        hr_image_path: path to the HR image.
        output_dir: directory receiving 'heatmap/<name>.npy' and
            'parsing/<name>.png'.
    """
    # Load the HR image
    image = cv2.imread(hr_image_path)
    h, w = image.shape[:2]
    # 1. Keypoint heatmaps: detect the 68 dlib landmarks
    detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    faces = detector(gray)
    if len(faces) == 0:
        print(f"未检测到人脸: {hr_image_path}")
        return
    # Collect landmark coordinates
    landmarks = predictor(gray, faces[0])
    keypoints = [(landmarks.part(i).x, landmarks.part(i).y) for i in range(68)]
    # PERF FIX: precompute the coordinate grid once — the original rebuilt
    # the meshgrid inside the loop, doing 68x the necessary work.
    xx, yy = np.meshgrid(np.arange(w), np.arange(h))
    sigma = 3.0  # Gaussian std-dev in pixels
    heatmap = np.zeros((68, h, w), dtype=np.float32)
    for i, (x, y) in enumerate(keypoints):
        # 2D Gaussian bump centred on the landmark
        dist = (xx - x)**2 + (yy - y)**2
        heatmap[i] = np.exp(-dist / (2 * sigma**2))
    # 2. Parsing map from an external face-parsing model (e.g. BiSeNet)
    parsing = generate_parsing_map(image)  # (H, W) class-index map
    # 3. Save artifacts
    base_name = os.path.splitext(os.path.basename(hr_image_path))[0]
    np.save(os.path.join(output_dir, 'heatmap', f'{base_name}.npy'), heatmap)
    cv2.imwrite(os.path.join(output_dir, 'parsing', f'{base_name}.png'), parsing)
    print(f"生成先验信息完成: {base_name}")
def generate_parsing_map(image):
    """
    Generate a face parsing map (pseudo-code — plug in a real model).

    Args:
        image: BGR image array (as loaded by cv2.imread).
    Returns:
        (H, W) array of per-pixel class indices (19 classes).
    """
    # In a real project use a pretrained face-parsing model,
    # e.g. https://github.com/zllrunning/face-parsing.PyTorch
    # This is only a sketch.
    from face_parsing import BiSeNet
    model = BiSeNet(n_classes=19)
    model.load_state_dict(torch.load('bisenet.pth'))
    model.eval()
    # Preprocess with ImageNet mean/std normalization
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    input_tensor = to_tensor(image).unsqueeze(0)
    with torch.no_grad():
        parsing = model(input_tensor)[0]  # (1, 19, H, W) class logits
        # Per-pixel argmax -> class indices, back to numpy
        parsing = parsing.argmax(dim=1).squeeze(0).cpu().numpy()
    return parsing
四、关键实现细节与技巧
4.1 训练策略
4.1.1 分阶段训练
def two_stage_training():
    """
    Illustrative two-stage training schedule (pseudo-code: `prior_net`,
    `model`, `lr_images`, `hr_images`, `priors`, `prior_loss` and `loss_fn`
    are assumed to exist in the surrounding training script).

    Stage 1 pre-trains the prior network alone; stage 2 fine-tunes both
    networks jointly with a reduced prior-network learning rate.
    """
    # Stage 1: pre-train the prior network on the prior loss only
    optimizer_prior = torch.optim.Adam(prior_net.parameters(), lr=1e-3)
    for epoch in range(50):
        prior_pred = prior_net(lr_images)
        loss = prior_loss(prior_pred, priors)
        # BUG FIX: gradients must be cleared every step — without zero_grad
        # they accumulate across iterations and corrupt the updates.
        optimizer_prior.zero_grad()
        loss.backward()
        optimizer_prior.step()
    # Stage 2: joint fine-tuning of both networks
    optimizer_joint = torch.optim.Adam([
        {'params': model.parameters()},
        {'params': prior_net.parameters(), 'lr': 1e-4}
    ], lr=1e-3)
    for epoch in range(100):
        prior_pred = prior_net(lr_images)
        sr_output = model(lr_images, prior_pred)
        loss = loss_fn(sr_output, hr_images, priors, prior_pred)
        optimizer_joint.zero_grad()  # same fix as above
        loss.backward()
        optimizer_joint.step()
4.1.2 混合精度训练
from torch.cuda.amp import autocast, GradScaler
def train_mixed_precision(model, prior_net, dataloader, optimizer, loss_fn, device):
    """One pass over dataloader using CUDA automatic mixed precision.

    Forward passes run under autocast; GradScaler scales the loss to avoid
    FP16 gradient underflow, then unscales before the optimizer step.
    """
    scaler = GradScaler()
    for lr_images, hr_images, priors in dataloader:
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)
        optimizer.zero_grad()
        # Mixed-precision forward pass
        with autocast():
            prior_pred = prior_net(lr_images)
            sr_output = model(lr_images, prior_pred)
            losses = loss_fn(sr_output, hr_images, priors, prior_pred)
        # Scaled backward, optimizer step, and scale-factor update
        total = losses['total']
        scaler.scale(total).backward()
        scaler.step(optimizer)
        scaler.update()
4.2 推理优化
4.2.1 模型量化(INT8)
def quantize_model(model):
    """Prepare `model` for post-training static INT8 quantization.

    Uses the 'fbgemm' (x86 server) backend config and inserts observers in
    place. The returned model still needs calibration on representative
    data; the final `convert` step is intentionally left commented out.
    """
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    torch.quantization.prepare(model, inplace=True)
    # Calibrate with a small batch of data, then finalize with:
    # torch.quantization.convert(model, inplace=True)
    return model
4.2.2 TensorRT加速
import torch_tensorrt
def convert_to_tensorrt(model, input_shape):
    """Compile the model into a TensorRT FP16 engine via torch_tensorrt.

    Args:
        model: CUDA-resident nn.Module to compile.
        input_shape: spatial input shape (H, W); a batch of 1 with 3
            channels is assumed by the example input below.
    Returns:
        the compiled TensorRT module.
    """
    # Trace to TorchScript first — torch_tensorrt.compile consumes a
    # traced/scripted module.
    traced_model = torch.jit.trace(model, torch.randn(1, 3, *input_shape).cuda())
    # Compile with FP16 enabled and a 1 GiB workspace.
    trt_model = torch_tensorrt.compile(
        traced_model,
        inputs=[torch.randn(1, 3, *input_shape).cuda()],
        enabled_precisions={torch.float16},  # FP16
        workspace_size=1 << 30,
        truncate_long_and_double=True
    )
    return trt_model
4.3 评估指标
def calculate_metrics(sr_image, hr_image):
    """
    Compute super-resolution quality metrics: PSNR, SSIM, and LPIPS.

    NOTE(review): lpips.LPIPS loads VGG weights on every call — hoist the
    model out of this function if it is used in a loop. Also, LPIPS expects
    torch tensors while PSNR/SSIM here operate on arrays — confirm the
    caller passes compatible inputs.
    """
    import lpips
    # PSNR: pixel fidelity
    psnr = peak_signal_to_noise_ratio(sr_image, hr_image)
    # SSIM: structural similarity
    ssim = structural_similarity(sr_image, hr_image, multichannel=True)
    # LPIPS: learned perceptual similarity
    lpips_loss = lpips.LPIPS(net='vgg')
    lpips_score = lpips_loss(sr_image, hr_image)
    return {
        'PSNR': psnr,
        'SSIM': ssim,
        'LPIPS': lpips_score.item()
    }
def peak_signal_to_noise_ratio(sr, hr):
    """Return PSNR in dB between two images on the 8-bit [0, 255] scale.

    Returns 100 for identical images (MSE == 0) by convention.
    """
    # BUG FIX: cast to float before subtracting — uint8 arrays wrap around
    # on subtraction (e.g. 0 - 10 == 246), silently corrupting the MSE.
    sr = np.asarray(sr, dtype=np.float64)
    hr = np.asarray(hr, dtype=np.float64)
    mse = np.mean((sr - hr) ** 2)
    if mse == 0:
        return 100
    return 20 * np.log10(255.0 / np.sqrt(mse))
def structural_similarity(sr, hr, multichannel=True):
    """Compute SSIM between two images.

    Args:
        sr, hr: image arrays, (H, W) or (H, W, C).
        multichannel: treat axis 2 as the channel axis when True.
    """
    from skimage.metrics import structural_similarity as ssim
    # BUG FIX: the original passed both `multichannel=` and `channel_axis=`.
    # `multichannel` was deprecated in scikit-image 0.19 and removed in 1.0,
    # so that call raises TypeError on current versions; `channel_axis`
    # alone expresses the same intent.
    return ssim(sr, hr, channel_axis=2 if multichannel else None)
五、常见问题与解决方案
5.1 训练不稳定
问题:损失震荡,不收敛 解决方案:
- 使用梯度裁剪:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
- 降低学习率:初始lr设为1e-4
- 使用学习率warmup:前几个epoch线性增加学习率
5.2 先验网络预测不准
问题:先验信息质量差,反而影响重建效果 解决方案:
- 预训练先验网络:单独训练PriorNet直到收敛
- 数据增强:对LR图像进行抖动,提升鲁棒性
- 损失权重调整:增大先验损失权重(λ_prior)
5.3 显存不足
问题:大batch size导致OOM 解决方案:
- 梯度累积:
accumulation_steps = 4
- 混合精度训练
- 减小网络通道数:
num_channels=32
5.4 推理速度慢
问题:单张图像推理时间过长 解决方案:
- 使用ONNX导出并优化
- TensorRT加速(可提升3-5倍)
- 批量推理:一次处理多张图像
六、FSRNet的改进与变体
6.1 FSRNet++
2019年提出的改进版本,主要改进:
- 注意力机制:在先验融合中引入通道注意力
- 更精细的多尺度特征:使用U-Net结构替代简单编解码
class FSRNetPlus(nn.Module):
    """FSRNet++: FSRNet variant with channel attention in the prior fusion.

    Sketch only — UNetEncoder, AttentionPriorFusion and UNetDecoder are
    assumed to be defined elsewhere in the project.
    """
    def __init__(self):
        super().__init__()
        self.encoder = UNetEncoder()  # U-Net encoder replaces the plain conv stack
        self.attention_fusion = AttentionPriorFusion()  # channel-attention prior fusion
        self.decoder = UNetDecoder()
6.2 与其他模型的结合
- ESRGAN + FSRNet:引入ESRGAN的判别器
- ARCNN + FSRNet:先压缩再重建,处理极端低分辨率
七、总结
FSRNet通过先验知识引导的创新思路,在人脸超分辨率领域取得了突破性进展。其核心价值在于:
- 领域知识融合:将人脸结构先验嵌入网络
- 多任务学习:联合优化重建与先验预测
- 渐进式重建:coarse-to-fine的生成策略
在实际应用中,需要注意:
- 数据准备:高质量的先验信息是关键
- 训练技巧:分阶段训练、混合精度
- 推理优化:模型量化、TensorRT加速
通过本文的代码解析,读者应能掌握FSRNet的完整实现流程,并可根据实际需求进行定制化改进。
参考文献:
- Chen, Y., Tai, Y., et al. “FSRNet: End-to-End Learning Face Super-Resolution with Facial Priors.” CVPR 2018.
- 代码参考:https://github.com/xyfJASON/FSRNet
- 先验生成工具:dlib, BiSeNet

# 深入解析FSRNet代码实现细节与核心算法原理剖析
引言:FSRNet在超分辨率领域的里程碑意义
FSRNet(Face Super-Resolution Network)是2018年由腾讯AI Lab提出的专门针对人脸图像超分辨率的深度学习模型。与通用的图像超分辨率方法不同,FSRNet创新性地引入了先验知识引导的思想,通过结合人脸关键点热图(heatmap)和解析图(parsing map)作为辅助信息,实现了从低分辨率人脸图像到高分辨率人脸图像的高质量重建。
本文将深入剖析FSRNet的代码实现细节与核心算法原理,帮助读者全面理解这一经典模型的设计哲学与技术实现。
一、FSRNet核心算法原理
1.1 问题定义与动机
传统超分辨率方法在处理人脸图像时存在明显局限:
- 缺乏领域知识:通用方法未利用人脸结构先验
- 细节模糊:难以恢复精细的人脸特征(如眼睛、眉毛、嘴唇)
- 伪影严重:在极端低分辨率情况下容易产生失真
FSRNet通过引入多任务学习和先验引导机制,有效解决了上述问题。
1.2 网络架构总览
FSRNet采用编码器-解码器结构,包含三个关键组件:
- 特征提取编码器:从LR图像中提取多尺度特征
- 先验信息融合模块:整合关键点热图与解析图
- 渐进式重建解码器:分阶段生成HR图像
1.3 核心创新点
1.3.1 先验信息引导机制
FSRNet的核心创新在于利用人工先验而非依赖大量数据:
- 关键点热图(Keypoint Heatmap):标注人脸关键点(如眼、鼻、嘴)的高斯分布图
- 解析图(Parsing Map):语义分割结果,标注五官、头发等区域
这些先验信息在训练时提供,在推理时通过辅助网络预测得到。
1.3.2 渐进式重建策略
采用coarse-to-fine策略:
- 第一阶段:生成粗糙的HR图像
- 第二阶段:基于先验信息进行细节优化
二、FSRNet代码实现详解
2.1 环境依赖与项目结构
# requirements.txt
torch>=1.2.0
torchvision>=0.4.0
numpy>=1.16.0
opencv-python>=3.4.0
scipy>=1.2.0
Pillow>=6.0.0
# 典型项目结构
FSRNet/
├── models/
│ ├── __init__.py
│ ├── fsrnet.py # 主网络定义
│ ├── prior_net.py # 先验预测网络
│ └── loss.py # 损失函数
├── data/
│ ├── dataset.py # 数据集处理
│ └── transforms.py # 数据增强
├── utils/
│ ├── utils.py # 工具函数
│ └── face_utils.py # 人脸相关工具
├── config.py # 配置文件
└── train.py # 训练脚本
2.2 核心网络模块代码解析
2.2.1 主网络架构(fsrnet.py)
import torch
import torch.nn as nn
import torch.nn.functional as F
class FSRNet(nn.Module):
"""
FSRNet主网络:实现人脸超分辨率重建
输入:LR图像 (3, H, W)
输出:HR图像 (3, 4H, 4W)
"""
def __init__(self, scale_factor=4, num_channels=64, num_prior_channels=32):
super(FSRNet, self).__init__()
self.scale_factor = scale_factor
# 1. 编码器:特征提取
self.encoder = nn.Sequential(
# 初始卷积:提取基础特征
nn.Conv2d(3, num_channels, kernel_size=3, stride=1, padding=1),
nn.ReLU(inplace=True),
# 残差块1:提取低级特征
ResidualBlock(num_channels, num_channels),
# 下采样:减少空间维度
nn.Conv2d(num_channels, num_channels*2, kernel_size=3, stride=2, padding=1),
nn.ReLU(inplace=True),
# 残差块2:提取中级特征
ResidualBlock(num_channels*2, num_channels*2),
# 下采样:进一步减少空间维度
nn.Conv2d(num_channels*2, num_channels*4, kernel_size=3, stride=2, padding=1),
nn.ReLU(inplace=True),
# 残差块3:提取高级特征
ResidualBlock(num_channels*4, num_channels*4),
)
# 2. 先验融合模块
self.prior_fusion = PriorFusionModule(
in_channels=num_channels*4,
prior_channels=num_prior_channels
)
# 3. 解码器:渐进式重建
self.decoder = nn.Sequential(
# 上采样1:从低分辨率恢复
nn.ConvTranspose2d(num_channels*4, num_channels*2,
kernel_size=4, stride=2, padding=1),
nn.ReLU(inplace=True),
ResidualBlock(num_channels*2, num_channels*2),
# 上采样2:进一步恢复细节
nn.ConvTranspose2d(num_channels*2, num_channels,
kernel_size=4, stride=2, padding=1),
nn.ReLU(inplace=True),
ResidualBlock(num_channels, num_channels),
# 输出层:生成最终HR图像
nn.Conv2d(num_channels, 3, kernel_size=3, stride=1, padding=1),
nn.Tanh() # 输出范围[-1,1],需后处理
)
def forward(self, x, prior_info=None):
"""
前向传播
Args:
x: LR图像 tensor (B, 3, H, W)
prior_info: 先验信息 dict {
'heatmap': (B, 68, H, W), # 68个关键点
'parsing': (B, 19, H, W) # 19类语义分割
}
Returns:
HR图像 tensor (B, 3, 4H, 4W)
"""
# 编码
features = self.encoder(x)
# 先验融合(训练时使用,推理时预测)
if prior_info is not None:
fused_features = self.prior_fusion(features, prior_info)
else:
# 推理模式:使用预测的先验
fused_features = features
# 解码重建
hr = self.decoder(fused_features)
# 后处理:调整到[0,1]范围
hr = (hr + 1) / 2.0
return hr
class ResidualBlock(nn.Module):
"""残差块:解决深层网络梯度消失问题"""
def __init__(self, in_channels, out_channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.relu(out)
out = self.conv2(out)
return out + residual # 恒等映射
class PriorFusionModule(nn.Module):
"""先验融合模块:整合关键点热图和解析图"""
def __init__(self, in_channels, prior_channels):
super(PriorFusionModule, self).__init__()
# 1. 先验特征提取器
self.prior_encoder = nn.Sequential(
nn.Conv2d(prior_channels, 32, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
)
# 2. 特征调制器:使用先验调制主特征
self.modulation = nn.Sequential(
nn.Conv2d(in_channels + 64, in_channels, kernel_size=1),
nn.ReLU(inplace=True),
)
def forward(self, main_features, prior_info):
"""
融合主特征与先验信息
Args:
main_features: (B, C, H', W')
prior_info: dict with 'heatmap' and 'parsing'
"""
# 拼接先验信息
prior_input = torch.cat([
prior_info['heatmap'], # (B, 68, H, W)
prior_info['parsing'] # (B, 19, H, W)
], dim=1) # (B, 87, H, W)
# 提取先验特征
prior_features = self.prior_encoder(prior_input) # (B, 64, H, W)
# 上采样先验特征到与主特征相同尺寸
prior_features = F.interpolate(
prior_features,
size=main_features.shape[2:],
mode='bilinear',
align_corners=False
)
# 拼接并调制
fused = torch.cat([main_features, prior_features], dim=1)
modulated = self.modulation(fused)
return modulated
2.3 先验预测网络(prior_net.py)
在推理时,需要预测关键点热图和解析图:
class PriorNet(nn.Module):
"""
先验预测网络:从LR图像预测关键点热图和解析图
采用轻量级设计,便于快速推理
"""
def __init__(self, num_keypoints=68, num_classes=19):
super(PriorNet, self).__init__()
# 共享特征提取器
self.shared_encoder = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
nn.ReLU(inplace=True),
)
# 关键点热图预测头
self.heatmap_head = nn.Sequential(
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, num_keypoints, kernel_size=1),
nn.Sigmoid() # 输出概率图
)
# 解析图预测头
self.parsing_head = nn.Sequential(
nn.Conv2d(128, 64, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(64, num_classes, kernel_size=1),
nn.Softmax(dim=1) # 输出分类概率
)
def forward(self, x):
"""
预测先验信息
Args:
x: LR图像 (B, 3, H, W)
Returns:
dict: {'heatmap': ..., 'parsing': ...}
"""
features = self.shared_encoder(x)
# 上采样到原始LR尺寸
heatmap = self.heatmap_head(features)
heatmap = F.interpolate(
heatmap,
size=x.shape[2:],
mode='bilinear',
align_corners=False
)
parsing = self.parsing_head(features)
parsing = F.interpolate(
parsing,
size=x.shape[2:],
mode='bilinear',
align_corners=False
)
return {'heatmap': heatmap, 'parsing': parsing}
2.4 损失函数设计(loss.py)
FSRNet采用多任务损失,包含多个分量:
class FSRNetLoss(nn.Module):
"""
FSRNet复合损失函数
L_total = L_recon + λ1*L_prior + λ2*L_perceptual + λ3*L_adversarial
"""
def __init__(self, lambda_prior=0.1, lambda_perceptual=0.01, lambda_adv=0.001):
super(FSRNetLoss, self).__init__()
self.lambda_prior = lambda_prior
self.lambda_perceptual = lambda_perceptual
self.lambda_adv = lambda_adv
# 重建损失(L1)
self.recon_loss = nn.L1Loss()
# 先验损失(关键点和解析图)
self.prior_loss = nn.MSELoss() # 热图用MSE
# 感知损失(使用预训练VGG)
self.perceptual_loss = PerceptualLoss()
# 对抗损失(可选)
self.adversarial_loss = nn.BCEWithLogitsLoss()
def forward(self, sr_output, hr_target, prior_info, prior_pred):
"""
计算总损失
Args:
sr_output: 网络输出的SR图像
hr_target: 真实HR图像
prior_info: 真实先验信息
prior_pred: 预测的先验信息
"""
# 1. 重建损失
loss_recon = self.recon_loss(sr_output, hr_target)
# 2. 先验损失(预测准确性)
loss_prior = 0
if prior_info is not None and prior_pred is not None:
# 关键点热图损失
loss_heatmap = self.prior_loss(
prior_pred['heatmap'],
prior_info['heatmap']
)
# 解析图损失(交叉熵)
loss_parsing = F.cross_entropy(
prior_pred['parsing'],
prior_info['parsing'].argmax(dim=1) # 转为类别标签
)
loss_prior = loss_heatmap + loss_parsing
# 3. 感知损失
loss_perceptual = self.perceptual_loss(sr_output, hr_target)
# 4. 对抗损失(训练判别器时使用)
loss_adversarial = 0 # 在判别器训练中计算
# 总损失
total_loss = (
loss_recon +
self.lambda_prior * loss_prior +
self.lambda_perceptual * loss_perceptual +
self.lambda_adv * loss_adversarial
)
return {
'total': total_loss,
'recon': loss_recon,
'prior': loss_prior,
'perceptual': loss_perceptual,
}
class PerceptualLoss(nn.Module):
"""感知损失:基于VGG19的特征空间距离"""
def __init__(self, layer_names=['relu3_3', 'relu4_3']):
super(PerceptualLoss, self).__init__()
self.vgg = VGG19Features(layer_names)
# 冻结VGG参数
for param in self.vgg.parameters():
param.requires_grad = False
def forward(self, sr, hr):
# 提取特征
sr_features = self.vgg(sr)
hr_features = self.vgg(hr)
# 计算特征距离
loss = 0
for sr_feat, hr_feat in zip(sr_features, hr_features):
loss += F.mse_loss(sr_feat, hr_feat)
return loss / len(sr_features)
2.5 训练流程(train.py)
def train_one_epoch(model, prior_net, dataloader, optimizer, loss_fn, device):
"""
单轮训练函数
"""
model.train()
prior_net.train()
total_loss = 0
for batch_idx, (lr_images, hr_images, priors) in enumerate(dataloader):
lr_images = lr_images.to(device)
hr_images = hr_images.to(device)
# 1. 预测先验信息
prior_pred = prior_net(lr_images)
# 2. 超分辨率重建
sr_output = model(lr_images, prior_pred)
# 3. 计算损失
losses = loss_fn(sr_output, hr_images, priors, prior_pred)
# 4. 反向传播
optimizer.zero_grad()
losses['total'].backward()
optimizer.step()
total_loss += losses['total'].item()
if batch_idx % 100 == 0:
print(f"Batch {batch_idx}: Loss={losses['total'].item():.4f}")
return total_loss / len(dataloader)
def train(model, prior_net, train_loader, val_loader, config):
"""
完整训练流程
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
prior_net = prior_net.to(device)
# 优化器
optimizer = torch.optim.Adam(
[
{'params': model.parameters()},
{'params': prior_net.parameters(), 'lr': config.lr * 0.1} # 先验网络学习率小一些
],
lr=config.lr,
betas=(0.9, 0.999)
)
# 学习率调度器
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
optimizer, mode='min', factor=0.5, patience=5
)
# 损失函数
loss_fn = FSRNetLoss(
lambda_prior=config.lambda_prior,
lambda_perceptual=config.lambda_perceptual
)
best_psnr = 0
for epoch in range(config.num_epochs):
# 训练
train_loss = train_one_epoch(
model, prior_net, train_loader, optimizer, loss_fn, device
)
# 验证
val_psnr = validate(model, prior_net, val_loader, device)
# 调度学习率
scheduler.step(-val_psnr) # PSNR越高越好
# 保存最佳模型
if val_psnr > best_psnr:
best_psnr = val_psnr
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'prior_net_state_dict': prior_net.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'psnr': val_psnr
}, 'best_model.pth')
print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, Val PSNR={val_psnr:.2f}")
2.6 推理代码实现
class FSRNetInference:
"""FSRNet推理封装类"""
def __init__(self, model_path, prior_net_path=None, device='cuda'):
"""
初始化推理器
Args:
model_path: FSRNet模型路径
prior_net_path: 先验预测网络路径(若为None则使用纯FSRNet)
"""
self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
# 加载FSRNet
self.model = FSRNet(scale_factor=4)
checkpoint = torch.load(model_path, map_location=self.device)
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.to(self.device)
self.model.eval()
# 加载先验网络(可选)
self.prior_net = None
if prior_net_path:
self.prior_net = PriorNet()
prior_checkpoint = torch.load(prior_net_path, map_location=self.device)
self.prior_net.load_state_dict(prior_checkpoint['model_state_dict'])
self.prior_net.to(self.device)
self.prior_net.eval()
# 图像预处理
self.transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
def preprocess(self, image_path):
"""预处理LR图像"""
image = Image.open(image_path).convert('RGB')
# 确保尺寸是4的倍数(便于下采样)
w, h = image.size
w = w - (w % 4)
h = h - (h % 4)
image = image.crop((0, 0, w, h))
# 转换为tensor
lr_tensor = self.transform(image).unsqueeze(0) # (1, 3, H, W)
return lr_tensor.to(self.device), image
def postprocess(self, sr_tensor):
"""后处理:tensor转PIL图像"""
sr_tensor = sr_tensor.squeeze(0).cpu()
sr_image = transforms.ToPILImage()(sr_tensor)
return sr_image
def __call__(self, image_path, use_prior=True):
"""
推理入口
Args:
image_path: LR图像路径
use_prior: 是否使用先验信息
Returns:
HR图像 (PIL)
"""
lr_tensor, lr_image = self.preprocess(image_path)
with torch.no_grad():
if use_prior and self.prior_net is not None:
# 预测先验
prior_pred = self.prior_net(lr_tensor)
# 超分辨率
sr_tensor = self.model(lr_tensor, prior_pred)
else:
# 纯FSRNet(性能较差)
sr_tensor = self.model(lr_tensor)
return self.postprocess(sr_tensor)
# 使用示例
if __name__ == '__main__':
# 初始化推理器
fsrnet = FSRNetInference(
model_path='checkpoints/fsrnet_best.pth',
prior_net_path='checkpoints/prior_net_best.pth'
)
# 推理单张图像
hr_image = fsrnet('data/lr_face.jpg')
hr_image.save('output/hr_face.jpg')
print(f"超分辨率完成!输出尺寸: {hr_image.size}")
三、数据准备与处理
3.1 数据集格式要求
FSRNet需要成对的LR-HR图像以及先验信息:
class FSRNetDataset(torch.utils.data.Dataset):
    """FSRNet dataset: yields LR/HR image pairs plus facial priors.

    Expects this layout under ``data_dir``:
        {split}.txt   -- one base file name per line
        HR/           -- high-resolution PNG images
        heatmap/      -- (68, H, W) float landmark heatmaps as .npy
        parsing/      -- (H, W) integer class-index parsing maps as PNG
    """
    def __init__(self, data_dir, split='train', scale_factor=4):
        """
        Args:
            data_dir: dataset root directory.
            split: 'train' or 'val'.
            scale_factor: super-resolution upscaling factor.
        """
        self.scale_factor = scale_factor
        self.data_dir = data_dir
        # BUGFIX: remember the split. __getitem__ previously referenced the
        # bare name `split`, which is not in scope there (NameError).
        self.split = split
        # File list for this split.
        self.image_list = []
        with open(os.path.join(data_dir, f'{split}.txt'), 'r') as f:
            for line in f:
                self.image_list.append(line.strip())
        # Photometric/geometric augmentation (applied on the train split only).
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
        ])

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        base_name = self.image_list[idx]
        # 1. Load the HR image.
        hr_path = os.path.join(self.data_dir, 'HR', f'{base_name}.png')
        hr_image = Image.open(hr_path).convert('RGB')
        # 2. Synthesize the LR image by bicubic downsampling.
        w, h = hr_image.size
        lr_w, lr_h = w // self.scale_factor, h // self.scale_factor
        lr_image = hr_image.resize((lr_w, lr_h), Image.BICUBIC)
        # 3. Load prior information.
        # Landmark heatmaps: (68, H, W).
        heatmap_path = os.path.join(self.data_dir, 'heatmap', f'{base_name}.npy')
        heatmap = np.load(heatmap_path)
        # Parsing map: (H, W) class indices, one-hot encoded to (19, H, W).
        parsing_path = os.path.join(self.data_dir, 'parsing', f'{base_name}.png')
        parsing = np.array(Image.open(parsing_path))
        parsing_onehot = np.eye(19)[parsing].transpose(2, 0, 1)
        # 4. Data augmentation (training split only).
        if self.transform and self.split == 'train':
            # Re-seed before each call so HR and LR draw identical random
            # flip/jitter parameters and stay aligned.
            # BUGFIX: 2**32 exceeds np.random.randint's bound on platforms
            # where the default integer is int32; stay within int32 range.
            seed = np.random.randint(0, 2**31 - 1)
            random.seed(seed)
            torch.manual_seed(seed)
            hr_image = self.transform(hr_image)
            random.seed(seed)
            torch.manual_seed(seed)
            lr_image = self.transform(lr_image)
        # 5. Convert to tensors, normalized to [-1, 1].
        to_tensor = transforms.ToTensor()
        lr_tensor = to_tensor(lr_image) * 2 - 1
        hr_tensor = to_tensor(hr_image) * 2 - 1
        prior_info = {
            'heatmap': torch.from_numpy(heatmap).float(),
            'parsing': torch.from_numpy(parsing_onehot).float(),
        }
        return lr_tensor, hr_tensor, prior_info
3.2 先验信息生成
在训练前,需要使用外部工具生成先验信息:
import cv2
import dlib
from scipy.ndimage import gaussian_filter
def generate_prior_info(hr_image_path, output_dir):
    """Generate prior information (landmark heatmaps + parsing map) for one HR image.

    Args:
        hr_image_path: path to the HR image.
        output_dir: output root; results are written to
            {output_dir}/heatmap/*.npy and {output_dir}/parsing/*.png.
    """
    image = cv2.imread(hr_image_path)
    h, w = image.shape[:2]
    # 1. Landmark heatmaps via dlib's 68-point detector.
    detector = dlib.get_frontal_face_detector()
    predictor = dlib.shape_predictor("shape_predictor_68_face_landmarks.dat")
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    faces = detector(gray)
    if len(faces) == 0:
        print(f"未检测到人脸: {hr_image_path}")
        return
    # Only the first detected face is used.
    landmarks = predictor(gray, faces[0])
    keypoints = [(landmarks.part(i).x, landmarks.part(i).y) for i in range(68)]
    # One 2D Gaussian per keypoint. PERF: the meshgrid is loop-invariant,
    # so build it once instead of rebuilding it 68 times.
    xx, yy = np.meshgrid(np.arange(w), np.arange(h))
    sigma = 3.0  # Gaussian standard deviation in pixels
    heatmap = np.zeros((68, h, w), dtype=np.float32)
    for i, (x, y) in enumerate(keypoints):
        dist = (xx - x) ** 2 + (yy - y) ** 2
        heatmap[i] = np.exp(-dist / (2 * sigma ** 2))
    # 2. Parsing map via an external face-parsing model (e.g. BiSeNet).
    parsing = generate_parsing_map(image)  # (H, W) class-index map
    # 3. Save. Create the output folders if they do not exist yet.
    os.makedirs(os.path.join(output_dir, 'heatmap'), exist_ok=True)
    os.makedirs(os.path.join(output_dir, 'parsing'), exist_ok=True)
    base_name = os.path.splitext(os.path.basename(hr_image_path))[0]
    np.save(os.path.join(output_dir, 'heatmap', f'{base_name}.npy'), heatmap)
    cv2.imwrite(os.path.join(output_dir, 'parsing', f'{base_name}.png'), parsing)
    print(f"生成先验信息完成: {base_name}")
def generate_parsing_map(image):
    """Predict a (H, W) face-parsing class-index map (illustrative code).

    Real projects should plug in a pretrained face-parsing model, e.g.
    https://github.com/zllrunning/face-parsing.PyTorch
    """
    from face_parsing import BiSeNet
    # PERF: load the model and its checkpoint once and reuse it across
    # calls instead of reloading the weights for every image.
    model = getattr(generate_parsing_map, '_model', None)
    if model is None:
        model = BiSeNet(n_classes=19)
        model.load_state_dict(torch.load('bisenet.pth'))
        model.eval()
        generate_parsing_map._model = model
    # NOTE(review): `image` comes from cv2.imread and is therefore BGR,
    # while these normalization constants are the ImageNet *RGB* stats —
    # confirm whether a BGR->RGB conversion is needed for the model used.
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    input_tensor = to_tensor(image).unsqueeze(0)
    with torch.no_grad():
        logits = model(input_tensor)[0]  # (1, 19, H, W)
    # Per-pixel argmax over the 19 classes -> (H, W) index map.
    return logits.argmax(dim=1).squeeze(0).cpu().numpy()
四、关键实现细节与技巧
4.1 训练策略
4.1.1 分阶段训练
def two_stage_training():
    """Two-stage training schedule (illustrative skeleton).

    Stage 1 pre-trains the prior network alone; stage 2 jointly fine-tunes
    FSRNet and the prior network with a 10x smaller prior-net learning rate.
    Relies on `prior_net`, `model`, `lr_images`, `hr_images`, `priors`,
    `prior_loss` and `loss_fn` being provided by the surrounding script.
    """
    # Stage 1: pre-train the prior network on the prior loss only.
    optimizer_prior = torch.optim.Adam(prior_net.parameters(), lr=1e-3)
    for epoch in range(50):
        # BUGFIX: clear stale gradients each iteration; without zero_grad()
        # gradients accumulate across iterations and the updates are wrong.
        optimizer_prior.zero_grad()
        prior_pred = prior_net(lr_images)
        loss = prior_loss(prior_pred, priors)
        loss.backward()
        optimizer_prior.step()
    # Stage 2: joint fine-tuning of both networks.
    optimizer_joint = torch.optim.Adam([
        {'params': model.parameters()},
        {'params': prior_net.parameters(), 'lr': 1e-4}
    ], lr=1e-3)
    for epoch in range(100):
        # BUGFIX: same gradient-reset requirement as above.
        optimizer_joint.zero_grad()
        prior_pred = prior_net(lr_images)
        sr_output = model(lr_images, prior_pred)
        loss = loss_fn(sr_output, hr_images, priors, prior_pred)
        loss.backward()
        optimizer_joint.step()
4.1.2 混合精度训练
from torch.cuda.amp import autocast, GradScaler
def train_mixed_precision(model, prior_net, dataloader, optimizer, loss_fn, device):
    """One epoch of mixed-precision (AMP) training.

    Args:
        model: FSRNet, called as model(lr_images, prior_pred).
        prior_net: prior prediction network, called on the LR batch.
        dataloader: yields (lr_images, hr_images, priors) batches, where
            priors is a tensor or a dict of tensors.
        optimizer: optimizer over both networks' parameters.
        loss_fn: returns a dict of loss tensors with key 'total'.
        device: target device for the batch tensors.
    """
    scaler = GradScaler()
    for lr_images, hr_images, priors in dataloader:
        lr_images = lr_images.to(device)
        hr_images = hr_images.to(device)
        # BUGFIX: move the priors to the same device as the model outputs;
        # otherwise loss_fn mixes CPU and GPU tensors when training on CUDA.
        if isinstance(priors, dict):
            priors = {k: v.to(device) for k, v in priors.items()}
        else:
            priors = priors.to(device)
        optimizer.zero_grad()
        # Mixed-precision forward pass.
        with autocast():
            prior_pred = prior_net(lr_images)
            sr_output = model(lr_images, prior_pred)
            losses = loss_fn(sr_output, hr_images, priors, prior_pred)
        # Scale the loss, backprop, then unscale-and-step.
        scaler.scale(losses['total']).backward()
        scaler.step(optimizer)
        scaler.update()
4.2 推理优化
4.2.1 模型量化(INT8)
def quantize_model(model):
    """Prepare a model for post-training INT8 quantization (fbgemm backend).

    Attaches the default fbgemm qconfig and inserts observer modules in
    place. Calibration on sample data and the final convert step are left
    to the caller.
    """
    # Default server-side (x86/fbgemm) quantization configuration.
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    # Insert observers so activation ranges can be calibrated.
    torch.quantization.prepare(model, inplace=True)
    # After calibrating on a few batches, finish with:
    # torch.quantization.convert(model, inplace=True)
    return model
4.2.2 TensorRT加速
import torch_tensorrt
def convert_to_tensorrt(model, input_shape):
    """Compile the model into a TensorRT engine via Torch-TensorRT.

    Args:
        model: the PyTorch model to accelerate.
        input_shape: spatial input shape (H, W); a batch of 1 with 3
            channels is assumed.

    Returns:
        The compiled TensorRT module.
    """
    example = torch.randn(1, 3, *input_shape).cuda()
    # Trace to TorchScript first, then hand the traced graph to TensorRT.
    scripted = torch.jit.trace(model, example)
    return torch_tensorrt.compile(
        scripted,
        inputs=[torch.randn(1, 3, *input_shape).cuda()],
        enabled_precisions={torch.float16},  # run in FP16
        workspace_size=1 << 30,
        truncate_long_and_double=True
    )
4.3 评估指标
def calculate_metrics(sr_image, hr_image):
    """Compute PSNR / SSIM / LPIPS between an SR result and its HR target.

    NOTE(review): PSNR and SSIM here operate on numpy arrays while
    lpips.LPIPS expects normalized torch tensors — confirm the caller
    passes representations compatible with all three metrics.
    """
    import lpips
    # PSNR (pixel fidelity).
    psnr = peak_signal_to_noise_ratio(sr_image, hr_image)
    # SSIM (structural similarity).
    ssim = structural_similarity(sr_image, hr_image, multichannel=True)
    # LPIPS (perceptual similarity). PERF: building the LPIPS network
    # loads VGG weights, so create it once and reuse it across calls.
    lpips_model = getattr(calculate_metrics, '_lpips', None)
    if lpips_model is None:
        lpips_model = lpips.LPIPS(net='vgg')
        calculate_metrics._lpips = lpips_model
    lpips_score = lpips_model(sr_image, hr_image)
    return {
        'PSNR': psnr,
        'SSIM': ssim,
        'LPIPS': lpips_score.item()
    }
def peak_signal_to_noise_ratio(sr, hr):
    """PSNR in dB, assuming an 8-bit [0, 255] pixel range.

    BUGFIX: cast to float64 before subtracting. With uint8 inputs (the
    usual image dtype) the difference wraps around modulo 256, which made
    the reported PSNR wildly optimistic.
    """
    sr = np.asarray(sr, dtype=np.float64)
    hr = np.asarray(hr, dtype=np.float64)
    mse = np.mean((sr - hr) ** 2)
    if mse == 0:
        # Identical images: keep the original large-sentinel convention.
        return 100
    return 20 * np.log10(255.0 / np.sqrt(mse))
def structural_similarity(sr, hr, multichannel=True):
    """SSIM via scikit-image.

    BUGFIX: pass only `channel_axis`. Forwarding the legacy `multichannel`
    keyword alongside `channel_axis` fails on recent scikit-image releases
    (`multichannel` was deprecated in 0.19 and later removed).
    """
    from skimage.metrics import structural_similarity as ssim
    # channel_axis=2 treats the last axis as color channels (H, W, C).
    return ssim(sr, hr, channel_axis=2 if multichannel else None)
五、常见问题与解决方案
5.1 训练不稳定
问题:损失震荡,不收敛 解决方案:
- 使用梯度裁剪:torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
- 降低学习率:初始lr设为1e-4
- 使用学习率warmup:前几个epoch线性增加学习率
5.2 先验网络预测不准
问题:先验信息质量差,反而影响重建效果 解决方案:
- 预训练先验网络:单独训练PriorNet直到收敛
- 数据增强:对LR图像进行抖动,提升鲁棒性
- 损失权重调整:增大先验损失权重(λ_prior)
5.3 显存不足
问题:大batch size导致OOM 解决方案:
- 梯度累积:accumulation_steps = 4(等效增大batch size)
- 混合精度训练
- 减小网络通道数:num_channels=32
5.4 推理速度慢
问题:单张图像推理时间过长 解决方案:
- 使用ONNX导出并优化
- TensorRT加速(可提升3-5倍)
- 批量推理:一次处理多张图像
六、FSRNet的改进与变体
6.1 FSRNet++
2019年提出的改进版本,主要改进:
- 注意力机制:在先验融合中引入通道注意力
- 更精细的多尺度特征:使用U-Net结构替代简单编解码
class FSRNetPlus(nn.Module):
"""FSRNet++:引入注意力机制"""
def __init__(self):
super().__init__()
self.encoder = UNetEncoder() # 改用U-Net
self.attention_fusion = AttentionPriorFusion()
self.decoder = UNetDecoder()
6.2 与其他模型的结合
- ESRGAN + FSRNet:引入ESRGAN的判别器
- ARCNN + FSRNet:先压缩再重建,处理极端低分辨率
七、总结
FSRNet通过先验知识引导的创新思路,在人脸超分辨率领域取得了突破性进展。其核心价值在于:
- 领域知识融合:将人脸结构先验嵌入网络
- 多任务学习:联合优化重建与先验预测
- 渐进式重建:coarse-to-fine的生成策略
在实际应用中,需要注意:
- 数据准备:高质量的先验信息是关键
- 训练技巧:分阶段训练、混合精度
- 推理优化:模型量化、TensorRT加速
通过本文的代码解析,读者应能掌握FSRNet的完整实现流程,并可根据实际需求进行定制化改进。
参考文献:
- Chen, Y., Tai, Y., Liu, X., Shen, C., Yang, J. “FSRNet: End-to-End Learning Face Super-Resolution with Facial Priors.” CVPR 2018.
- 代码参考:https://github.com/xyfJASON/FSRNet
- 先验生成工具:dlib, BiSeNet
