掘金 人工智能 08月19日
显微镜图像处理【优化】- 使用图像风格迁移技术放大图像细节
index_new5.html
../../../zaker_core/zaker_tpl_static/wap/tpl_guoji1.html

 

本文探讨了在显微镜图像细节增强领域的深度学习优化方法。考虑到生产环境的限制和高分辨率图像处理的挑战,文章推荐将大尺寸图像分割为512*512的小块进行处理,以降低显存消耗和提高效率。作者介绍了使用pix2pix模型进行图像风格迁移,并详细阐述了从PyTorch模型转换为ONNX格式、图像拆分与拼合、以及最终输出优化图的完整流程。文章强调了这种方法在统一图像风格和展现细节方面的潜力,但同时严正警告在生产环境中不推荐使用深度学习进行图像优化。

🔬 **显微镜图像优化挑战与解决方案**:显微镜拍摄的图像常面临光照不均和风格不统一的问题。为解决这些挑战,文章提出一种基于深度学习的图像优化方案,核心在于将高分辨率图像分割成512*512的小块进行处理,以克服显存限制和提高训练、部署效率,从而在不同光照条件下都能良好地展现图像细节。

🔄 **Pix2Pix实现图像风格迁移**:文章选用了pix2pix模型进行图像风格迁移,该技术能根据预设风格将输入图像的每一个像素进行转换,实现图像风格的统一。通过将复杂的图像风格迁移任务分解为小块处理,可以有效降低对计算资源的要求,使其更适用于高分辨率显微镜图像的处理。

⚙️ **ONNX模型部署流程**:为实现模型的实际应用,文章详细介绍了将PyTorch模型转换为ONNX格式的步骤。ONNX作为一种开放的模型交换格式,便于跨平台部署。通过ONNX,模型可以被高效地加载和运行,为后续的图像处理流程奠定基础。

🧩 **图像预处理与后处理**:整个优化流程包含了精细的图像预处理和后处理。这包括利用Labelme标注文件生成图像掩码,对掩码进行膨胀处理以避免信息丢失;将图像分割成小块(patch),对每个patch进行模型预测优化;最后,利用重叠块的平滑拼接技术将优化后的patch还原成完整的图像,并与原图进行融合,确保细节的连续性和自然性。

⚠️ **生产环境的审慎态度**:文章在多次强调其作为一种技术实践和个人练习的性质,并郑重警告,不推荐在生产环境中使用深度学习技术进行显微镜图像的优化,这表明作者对技术在实际应用中的风险和局限性有清醒的认识。

1. 前景提要

首先:在生产环境不推荐使用任何深度学习技术对图像进行优化!在生产环境不推荐使用任何深度学习技术进行图像优化!在生产环境不推荐使用任何深度学习技术进行图像优化!本篇文章单纯记录一种显微镜细节增强领域的图像优化手段。仅供参考!!

根据显微镜不同的光源条件,拍摄的图片可能有明有暗。要怎么做到不管是明亮的图像还是灰暗的图像,他的图像细节都能够相对非常良好的展现出来?

我们会遇到各种各样的图像风格,如何让这些拍摄图像的风格趋近于统一成了一个值得研究的业务问题。

但是,图像风格迁移技术是一项非常耗费显存的技术,而显微镜拍摄的图像又要求有极高的分辨率。如果说我们将一张1600*1200的正常百倍镜图像传入模型进行训练的话。你光预测耗时就会非常的长。你的服务器成本会很高!更别说端到端客户端部署了。

所以我推荐将一张图像切分成多张512*512分辨率的图像,这样方便训练,集成和部署。

图像风格迁移技术

这里我们使用pix2pix(pytorch)来训练我们的图像分格迁移模型:github.com/junyanz/pyt…

【图像来自pix2pix-pytorch的github首页】

上面这张图像简单说明了pix2pix是干啥的,说高大上点叫做图像风格迁移,说简单点就是将图片中的每一个像素根据预设的风格转换成新的一张图片。

训练

[训练过程省略...]

部署

1. 先将pytorch模型转换为onnx文件。这是部署的第一步!

import torchimport torch.onnxfrom optim_algorithm.optim1.index import imageEnhanceModeldef convert_to_onnx(pytorch_model, dummy_input, output_path="model.onnx"):    pytorch_model.eval()    # 导出为ONNX    torch.onnx.export(        pytorch_model,  # 模型        dummy_input,  # 示例输入        output_path,  # 输出文件路径        export_params=True,  # 导出训练参数        opset_version=11,  # ONNX算子集版本        do_constant_folding=True,  # 优化常量折叠        input_names=['input'],  # 输入节点名称        output_names=['output'],  # 输出节点名称        dynamic_axes={  # 动态维度配置            'input': {0: 'batch_size'},            'output': {0: 'batch_size'}        }    )    print(f"模型已成功导出到 {output_path}")# 创建动态尺寸的dummy_input# 假设输入格式为 [batch, channel, height, width]batch_size = 1  # 可以是任意值channels = 3  # RGB图像dummy_input = torch.randn(1, 3, 512, 512)convert_to_onnx(imageEnhanceModel.model,dummy_input)

这样我们就得到了一个图像优化的onnx文件。再传入netron.app/ 这个网站分析onnx结构。知道它的输入是什么。输出是什么。

【netron可视化onnx节点】

从上图我们可以得知:

好了,现在我们就得到了一个可用的图像优化onnx了。现在我们来打通从图像输入->图像拆分拼图(每一个拼图为512*512,分辨率尺寸不够就填充白背景)->传入onnx模型优化->优化图像还原拼图->输出优化图

这样一个全过程!

全过程

图像前处理

from .index import optimProcessimport numpy as npimport cv2def enhance_image(mat,shapes):    # 2. 更具shapes绘制mask    mask = np.zeros(mat.shape[:2], np.uint8)+255    for shape in shapes:        points = np.array(shape["points"]).astype(np.int32)        mask = cv2.drawContours(mask, points[np.newaxis, :, np.newaxis, :], 0, (0, 0, 0), cv2.FILLED)    return optimProcess(mat,mask)

上面是图像的处理全过程,其中mat为输入的原图图像,shapes为labelme文件的shapes(labelme是啥不过多赘述)

optimProcess【图像处理】

def optimProcess(imageData,imageMask):    # 膨胀mask,防止分割范围过小缺失信息    kernel = np.ones((5, 5), np.uint8)    imageMask_dilate = copy.deepcopy(255 - cv2.dilate(255 - imageMask, kernel, iterations=2))    index2 = np.where(imageMask_dilate == 0)    if len(index2) != 0 and len(index2[1]) != 0:        x1 = np.min(index2[1])        x2 = np.max(index2[1])        y1 = np.min(index2[0])        y2 = np.max(index2[0])        # 提取roi区域进行优化        targetImageArray = imageData[y1: y2, x1: x2]        imageMask_roi = imageMask_dilate[y1: y2, x1: x2]        imageOptim = optimAllImage(targetImageArray,imageMask_roi)        # 得到优化图拼接回原图尺寸        imageOptim_all = np.zeros(imageData.shape, np.uint8) + 255        if len(imageOptim_all.shape) == 3 and len(imageOptim.shape) == 2:            imageOptim = cv2.cvtColor(imageOptim, cv2.COLOR_GRAY2BGR)        elif len(imageOptim_all.shape) == 2 and len(imageOptim.shape) == 3:            imageOptim = cv2.cvtColor(imageOptim, cv2.COLOR_BGR2GRAY)        imageOptim_all[y1: y2, x1: x2] = imageOptim        imageOptim_all = cv2.add(imageMask_dilate, imageOptim_all)        return imageOptim_all    else:        return imageData

optimAllArray 【将根据mask裁剪的roi图像做拼图】

def optimAllImage(img,mask):    if len(img.shape) != 2:        image_array = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)    if len(mask.shape) != 2:        mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)    img = cv2.add(img, mask)    patchs,_,r_size = img2patchs(img,patch_size=(block_size, block_size), overlap_size=(overlap, overlap))    N = len(patchs)    M = len(patchs[0])    process_patchs = []    for n in range(N):        blocks = []        for m in range(M):            patch = patchs[n][m]            patch = predict(patch)            if len(patch.shape) != 2:                patch = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)            blocks.append(patch.astype(np.uint8))        process_patchs.append(blocks)    recovery_img = patchs2img(process_patchs, r_size, (overlap, overlap))    return recovery_img

单独优化每一块拼图

img2patchs

def img2patchs(img,patch_size=(block_size, block_size), overlap_size=(overlap, overlap)):    h, w = img.shape    ph, pw = patch_size    oh, ow = overlap_size    r_h = (h - ph) % (ph - oh)    r_w = (w - pw) % (pw - ow)    target_w, target_h = w, h    if not (h >= ph > oh and w >= pw > ow):        return [[img]], (target_h, target_w), (0, 0)    N = math.ceil((target_h - ph) / (ph - oh)) + 1    M = math.ceil((target_w - pw) / (pw - ow)) + 1    patchs_all = []    for n in range(N):        patchs_row = []        for m in range(M):            if n == N - 1:                ph_start = target_h - ph            else:                ph_start = n * (ph - oh)            if m == M - 1:                pw_start = target_w - pw            else:                pw_start = m * (pw - ow)            patch = img[ph_start:(ph_start + ph), pw_start:(pw_start + pw)]            patchs_row.append(patch)        patchs_all.append(patchs_row)    return patchs_all, (target_h, target_w), (r_h, r_w)

predict

def predict(image):    # 将image从bgr转换为rgb    image = process_input_image(image,512)    image = np.array(image)    mean = (127,127,127)    std = (127,127,127)    img_tensor = (image - mean) / std    img_tensor = img_tensor.astype('float32')    img_tensor = np.transpose(img_tensor, (2, 0, 1))    img_tensor = np.expand_dims(img_tensor, axis=0)    ort_inputs ={"input": img_tensor}    output = ort_session.run(['output'], ort_inputs)    output = convert_to_numpy(output)    output = clip_image(output, image, 512)    h1, w1 = image.shape[:2]    h2, w2 = output.shape[:2]    pad_h = h1 - h2    pad_w = w1 - w2    if pad_h > 0:        output = np.pad(output, ((0, pad_h), (0, 0), (0, 0)), mode='constant', constant_values=255)    else:        output = output[:h1, :, :]    if pad_w > 0:        output = np.pad(output, ((0, 0), (0, pad_w), (0, 0)), mode='constant', constant_values=255)    else:        output = output[:, :w1, :]    return output

patchs2img【还原拼图】

def patchs2img(patchs, r_size, overlap_size=(20, 20)):    N = len(patchs)    M = len(patchs[0])    # print("N:{}, M:{}".format(N, M))    oh, ow = overlap_size    patch_shape = patchs[0][0].shape    ph, pw = patch_shape[:2]    r_h, r_w = r_size    mode = 'GRAY' if len(patch_shape) == 2 else 'RGB'    if N == 1 and M == 1:        return_img = patchs[0][0]        return return_img# if mode == 'RGB' else cv2.cvtColor(return_img, cv2.COLOR_GRAY2RGB)    row_imgs = []    for n in range(N):        row_img = patchs[n][0]# if mode == 'RGB' else cv2.cvtColor(patchs[n][0], cv2.COLOR_GRAY2RGB)        for m in range(1, M):            if m == M - 1 and r_w != 0:                ow_new = pw - r_w            else:                ow_new = ow            # ow_new = ow            patch = patchs[n][m]# if mode == 'RGB' else cv2.cvtColor(patchs[n][m], cv2.COLOR_GRAY2RGB)            # print(mode, patch.shape)            h, w = row_img.shape[:2]            new_w = w + pw - ow_new            big_row_img = np.zeros((h, new_w), dtype=np.uint8)            big_row_img[:, :w - ow_new] = row_img[:, :w - ow_new]            big_row_img[:, w:] = patch[:, ow_new:]            overlap_row_01 = row_img[:, w - ow_new:]            overlap_row_02 = patch[:, :ow_new]            # get weight            weight = vertical_grad(overlap_row_01.shape, 0, 255, mode='w') / 255            overlap_row = (overlap_row_01 * (1 - weight)).astype(np.uint8) + (overlap_row_02 * weight).astype(np.uint8)            big_row_img[:, w - ow_new:w] = overlap_row            row_img = big_row_img        row_imgs.append(row_img)    column_img = row_imgs[0]    for i in range(1, N):        if i == N - 1 and r_h != 0:            oh_new = ph - r_h        else:            oh_new = oh        # oh_new = oh        row_img = row_imgs[i]        h, w = column_img.shape[:2]        new_h = h + ph - oh_new        big_column_img = np.zeros((new_h, w), dtype=np.uint8)        big_column_img[:h - oh_new, :] = column_img[:h - oh_new, :]        big_column_img[h:, :] = row_img[oh_new:, :]        overlap_column_01 = column_img[h - oh_new:, :]        overlap_column_02 = row_img[:oh_new, :]        # get weight        weight = vertical_grad(overlap_column_01.shape, 0, 255, mode='h') / 255        overlap_column = (overlap_column_01 * (1 - weight)).astype(np.uint8) + (overlap_column_02 * weight).astype(np.uint8)        big_column_img[h - oh_new:h, :] = overlap_column        column_img = big_column_img    return column_img

优化效果

【原图】

【优化图】

再次警告!不推荐使用深度学习技术在显微成像领域做图像优化,本篇文章纯粹记录本人练习onnx模型部署有感而发!不推荐在生产环境用深度学习技术做图像优化!

Fish AI Reader

Fish AI Reader

AI辅助创作,多种专业模板,深度分析,高质量内容生成。从观点提取到深度思考,FishAI为您提供全方位的创作支持。新版本引入自定义参数,让您的创作更加个性化和精准。

FishAI

FishAI

鱼阅,AI 时代的下一个智能信息助手,助你摆脱信息焦虑

联系邮箱 441953276@qq.com

相关标签

显微镜图像 图像优化 深度学习 pix2pix ONNX
相关文章