AI Automatic Colorization: Bringing Century-Old Black-and-White Footage Back in Color

看看科技观察芯情 2024-01-04 16:38:02

Not long ago, an AI technique caused quite a stir online. Someone used automatic colorization to add color to black-and-white films shot a century ago, bringing scenes of life from a hundred years ago vividly back before modern eyes, and giving us a chance to glimpse how our predecessors lived. This section examines the technical principles behind automatic colorization of black-and-white images in detail.

Basic Principles of the Algorithm

Image colorization is a hard problem in AI precisely because it has no single "correct answer". An object's color admits many plausible possibilities: a dress could reasonably be painted red or blue. This inherent ambiguity means the algorithm designer has many factors to weigh, which makes the problem difficult.

For this reason, many colorization techniques require human intervention: when the algorithm cannot settle on a specific color, a person must choose for it, so earlier methods rarely achieved full automation, that is, the complete removal of human input. The algorithm studied in this section is notable precisely because it eliminates human intervention entirely while still producing colorizations of reliable quality.

Let us first look at the algorithm's overall structure, shown in Figure 1.

Figure 1: Network architecture

As Figure 1 shows, the network we will build differs from earlier ones. The discriminator takes two images at once: a grayscale image and the corresponding original color image. A network of this kind is called a conditional network, and the grayscale image is the condition: the discriminator must judge the color image based on the information the grayscale image provides.

The basic idea is this: first, pair each color image with its grayscale version and feed the pairs to the discriminator so that it learns how objects are actually colored in the real world. Then feed grayscale images to the generator and have it produce corresponding color versions.

Next, feed the grayscale image together with the generator's colorized output to the discriminator. Since the discriminator has been trained to recognize the color patterns of real objects, it can detect the generator's coloring mistakes. This error signal flows back to the generator, which corrects itself; round after round, the generator's coloring improves, until the discriminator can barely tell its output from real color images.
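In loss terms, this game between the two networks can be written as a pair of objectives. Below is a minimal sketch, assuming both networks output raw scores (logits); the full training code later in this section uses exactly this sigmoid cross-entropy form:

import tensorflow as tf

# Minimal sketch of the two adversarial objectives. `d_real` is the
# discriminator's score on a (grayscale, true color) pair, `d_fake` its
# score on a (grayscale, generated color) pair; both are raw logits.
def discriminator_loss(d_real, d_fake):
    # Push scores for real pairs toward 1 and for generated pairs toward 0.
    loss_real = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_real, labels=tf.ones_like(d_real))
    loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_fake, labels=tf.zeros_like(d_fake))
    return tf.reduce_mean(loss_real) + tf.reduce_mean(loss_fake)

def generator_loss(d_fake):
    # The generator wants its colorizations to be scored like real pairs.
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        logits=d_fake, labels=tf.ones_like(d_fake)))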

Internal Network Design

The generator uses the U-net model introduced earlier; its structure is shown in Figure 2.

Figure 2: Generator structure

As Figure 2 shows, the generator takes a 256×256 grayscale image as input, passes it through 8 convolutional blocks that extract the structure of the objects in the image, and then through 8 transposed-convolution blocks that restore the image. During this restoration, each pixel is assigned a color, which is what implements the colorization.

It uses the U-net architecture described earlier: each convolutional block passes the features it extracts directly to its matching transposed-convolution block. The decoder thus receives the information the encoder saw, letting it grasp the shapes of the objects in the image and assign pixel colors more accurately. A minimal sketch of this skip-connection idea follows.
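To make the skip connection concrete, here is a toy two-level U-net in the Keras functional API. The depth and filter counts here are illustrative only; the actual generator, with its 8 encoder blocks, is defined in the code section below.

import tensorflow as tf

# Illustrative two-level U-net: each encoder output is concatenated with the
# decoder feature map of the same spatial size (the "skip connection").
inputs = tf.keras.Input(shape=(256, 256, 1))
e1 = tf.keras.layers.Conv2D(64, 4, strides=2, padding='same', activation='relu')(inputs)       # 128x128
e2 = tf.keras.layers.Conv2D(128, 4, strides=2, padding='same', activation='relu')(e1)          # 64x64
d1 = tf.keras.layers.Conv2DTranspose(64, 4, strides=2, padding='same', activation='relu')(e2)  # 128x128
d1 = tf.keras.layers.concatenate([d1, e1])  # skip connection: reuse encoder detail
d2 = tf.keras.layers.Conv2DTranspose(32, 4, strides=2, padding='same', activation='relu')(d1)  # 256x256
outputs = tf.keras.layers.Conv2D(3, 1, padding='same', activation='tanh')(d2)  # per-pixel color
toy_unet = tf.keras.Model(inputs, outputs)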

Next, look at the discriminator's internal structure, shown in Figure 3.

Figure 3: Discriminator structure

As Figure 3 shows, when training the discriminator, the algorithm first combines an image's grayscale version with the original color version into a single multi-channel array (concatenating them along the channel dimension), then runs this combined input through several convolutional layers, which learn the logical connection between the objects in the image and their true colors.

Finally, the discriminator outputs a score: if the color image is a correct colorization of the grayscale image, the value should be as large as possible; if the two do not match, as small as possible. The discriminator is therefore trained to return a large value when the grayscale image is paired with the true color image, and a small value when it is paired with the generator's output.
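Concretely, "pairing" the two images means stacking them along the channel axis: the 1-channel grayscale image and the 3-channel color image become a single 4-channel input. A minimal sketch with dummy data:

import numpy as np
import tensorflow as tf

# Pair a grayscale image with a color image by stacking along the channel axis.
gray  = np.zeros((1, 256, 256, 1), dtype=np.float32)  # the "condition"
color = np.zeros((1, 256, 256, 3), dtype=np.float32)  # real or generated colors
pair = tf.concat([gray, color], axis=3)               # shape (1, 256, 256, 4)
print(pair.shape)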

When training the generator, the goal is the opposite: its colorized output, combined with the corresponding grayscale image and fed to the discriminator, should make the discriminator return as large a value as possible. The larger the discriminator's score on the generator's colorization, the better the coloring quality.

Code Implementation

This section implements the algorithm and network structures described above, then drives the training loop so that the generator learns to color a given grayscale image correctly. First, construct the generator and the discriminator:

import numpy as np
import tensorflow as tf

class Generator(tf.keras.Model):
    def __init__(self, encoder_kernel, decoder_kernel):
        super(Generator, self).__init__()
        self.encoder_kernels = encoder_kernel  # convolution-layer parameters
        self.decoder_kernels = decoder_kernel  # transposed-convolution-layer parameters
        self.kernel_size = 4
        self.output_channels = 3  # final output is a 3-channel color image
        self.left_size_layers = []
        self.right_size_layers = []
        self.last_layers = []
        self.create_network()

    def create_network(self):  # build the generator network
        for index, kernel in enumerate(self.encoder_kernels):
            # convolution blocks that extract structure from the input image
            down_sample_layers = []
            down_sample_layers.append(tf.keras.layers.Conv2D(
                kernel_size=self.kernel_size,
                filters=kernel[0],
                strides=kernel[1],
                padding='same'
            ))
            down_sample_layers.append(tf.keras.layers.BatchNormalization())
            down_sample_layers.append(tf.keras.layers.LeakyReLU())
            self.left_size_layers.append(down_sample_layers)
        for index, kernel in enumerate(self.decoder_kernels):
            # transposed-convolution blocks that restore resolution and assign pixel colors
            up_sample_layers = []
            up_sample_layers.append(tf.keras.layers.Conv2DTranspose(
                kernel_size=self.kernel_size,
                filters=kernel[0],
                strides=kernel[1],
                padding='same'
            ))
            up_sample_layers.append(tf.keras.layers.BatchNormalization())
            up_sample_layers.append(tf.keras.layers.ReLU())
            self.right_size_layers.append(up_sample_layers)
        self.last_layers.append(tf.keras.layers.Conv2D(
            kernel_size=1,
            filters=self.output_channels,
            strides=1,
            padding='same',
            activation='tanh'
        ))  # produce the color image

    def call(self, x, training=None):
        x = tf.convert_to_tensor(x, dtype=tf.float32)
        left_layer_results = []
        for layers in self.left_size_layers:
            for layer in layers:
                x = layer(x)
            left_layer_results.append(x)
        left_layer_results.reverse()
        idx = 0
        x = left_layer_results[idx]
        for layers in self.right_size_layers:
            # U-net skip connection: hand the matching encoder output
            # directly to the decoder block of the same spatial size
            corresponding_left = left_layer_results[idx + 1]
            idx += 1
            for layer in layers:
                x = layer(x)
            x = tf.keras.layers.concatenate([x, corresponding_left])
        for layers in self.last_layers:
            x = layers(x)
        return x

    def create_variables(self):  # build the generator's weights
        dummy1 = np.zeros((1, 256, 256, 1))
        self.call(dummy1)

class Discriminator(tf.keras.Model):
    def __init__(self, encoder_kernel):
        super(Discriminator, self).__init__()
        self.kernels = encoder_kernel  # discriminator convolution-layer parameters
        self.discriminator_layers = []
        self.kernel_size = 4
        self.create_network()

    def create_network(self):
        for index, kernel in enumerate(self.kernels):
            # convolution blocks that analyze the (grayscale, color) pair
            self.discriminator_layers.append(tf.keras.layers.Conv2D(
                kernel_size=self.kernel_size,
                filters=kernel[0],
                strides=kernel[1],
                padding='same'
            ))
            self.discriminator_layers.append(tf.keras.layers.BatchNormalization())
            self.discriminator_layers.append(tf.keras.layers.LeakyReLU())
        self.discriminator_layers.append(tf.keras.layers.Conv2D(
            kernel_size=4,
            filters=1,
            strides=1,
            padding='same'
        ))  # output a score indicating how plausible the coloring is

    def call(self, x, training=None):
        x = tf.convert_to_tensor(x, dtype=tf.float32)
        for layer in self.discriminator_layers:
            x = layer(x)
        return x

    def create_variables(self):  # build the discriminator's weights
        dummy1 = np.zeros((1, 256, 256, 1))
        dummy2 = np.zeros((1, 256, 256, 3))
        x = np.concatenate((dummy1, dummy2), axis=3)
        self.call(x)
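As a quick smoke test, the two classes can be instantiated with the layer-parameter lists that the training driver below will use (each tuple is (filters, stride); the kernel size is fixed at 4):

generator_encoder = [(64, 1), (64, 2), (128, 2), (256, 2),
                     (512, 2), (512, 2), (512, 2), (512, 2)]
generator_decoder = [(512, 2), (512, 2), (512, 2), (256, 2),
                     (128, 2), (64, 2), (64, 2)]
discriminator_encoder = [(64, 2), (128, 2), (256, 2), (512, 1)]

generator = Generator(generator_encoder, generator_decoder)
generator.create_variables()       # builds weights from a dummy 256x256 grayscale input
discriminator = Discriminator(discriminator_encoder)
discriminator.create_variables()   # builds weights from a dummy 4-channel pair
print(generator(np.zeros((1, 256, 256, 1))).shape)  # expect (1, 256, 256, 3)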

Next we need the data-preprocessing code, which covers loading images, converting color images to grayscale, and converting between the LAB and RGB color formats. To improve colorization quality, the algorithm has the generator produce images in LAB format, because the perceived quality of an image depends heavily on luminance.

RGB does not represent luminance as a separate channel, but LAB does; training the generator to output LAB images therefore lets the network control luminance directly, which noticeably improves the results. Once a LAB image is produced, the code converts it back to RGB for display. The relevant implementation follows:

import os
import sys
import time
import random
import pickle
import numpy as np
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt

COLORSPACE_RGB = 'RGB'
COLORSPACE_LAB = 'LAB'

def stitch_images(grayscale, original, pred):
    # "stitch" the grayscale image, the original color image, and the
    # network-colorized result into one image for side-by-side comparison
    gap = 5
    width, height = original[0][:, :, 0].shape
    img_per_row = 2 if width > 200 else 4
    img = Image.new('RGB',
                    (width * img_per_row * 3 + gap * (img_per_row - 1),
                     height * int(len(original) / img_per_row)))
    grayscale = np.array(grayscale).squeeze()
    original = np.array(original)
    pred = np.array(pred)
    for ix in range(len(original)):
        xoffset = int(ix % img_per_row) * width * 3 + int(ix % img_per_row) * gap
        yoffset = int(ix / img_per_row) * height
        im1 = Image.fromarray(grayscale[ix])
        im2 = Image.fromarray(original[ix])
        im3 = Image.fromarray((pred[ix] * 255).astype(np.uint8))
        img.paste(im1, (xoffset, yoffset))
        img.paste(im2, (xoffset + width, yoffset))
        img.paste(im3, (xoffset + width + width, yoffset))
    return img

def imshow(img, title=''):
    # display an image
    fig = plt.gcf()
    fig.canvas.set_window_title(title)
    plt.axis('off')
    plt.imshow(img, interpolation='none')
    plt.show()

# conversion between RGB and LAB formats
def preprocess(img, colorspace_in, colorspace_out):
    if colorspace_out.upper() == COLORSPACE_RGB:
        if colorspace_in == COLORSPACE_LAB:
            img = lab_to_rgb(img)
        # [0, 1] => [-1, 1]
        img = (img / 255.0) * 2 - 1
    elif colorspace_out.upper() == COLORSPACE_LAB:
        if colorspace_in == COLORSPACE_RGB:
            img = rgb_to_lab(img / 255.0)
        L_chan, a_chan, b_chan = tf.unstack(img, axis=3)
        # L: [0, 100] => [-1, 1]
        # A, B: [-110, 110] => [-1, 1]
        img = tf.stack([L_chan / 50 - 1, a_chan / 110, b_chan / 110], axis=3)
    return img

def postprocess(img, colorspace_in, colorspace_out):
    if colorspace_in.upper() == COLORSPACE_RGB:
        # [-1, 1] => [0, 1]
        img = (img + 1) / 2
        if colorspace_out == COLORSPACE_LAB:
            img = rgb_to_lab(img)
    elif colorspace_in.upper() == COLORSPACE_LAB:
        L_chan, a_chan, b_chan = tf.unstack(img, axis=3)
        # L: [-1, 1] => [0, 100]
        # A, B: [-1, 1] => [-110, 110]
        img = tf.stack([(L_chan + 1) / 2 * 100, a_chan * 110, b_chan * 110], axis=3)
        if colorspace_out == COLORSPACE_RGB:
            img = lab_to_rgb(img)
    return img

def rgb_to_lab(srgb):
    # based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c
    with tf.name_scope("rgb_to_lab"):
        srgb_pixels = tf.reshape(srgb, [-1, 3])
        srgb_pixels = tf.cast(srgb_pixels, tf.float32)
        with tf.name_scope("srgb_to_xyz"):
            linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
            exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
            rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + \
                         (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
            rgb_to_xyz = tf.constant([
                #    X         Y         Z
                [0.412453, 0.212671, 0.019334],  # R
                [0.357580, 0.715160, 0.119193],  # G
                [0.180423, 0.072169, 0.950227],  # B
            ])
            xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)
        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("xyz_to_cielab"):
            # normalize for D65 white point
            xyz_normalized_pixels = tf.multiply(xyz_pixels, [1 / 0.950456, 1.0, 1 / 1.088754])
            epsilon = 6 / 29
            linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon ** 3), dtype=tf.float32)
            exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon ** 3), dtype=tf.float32)
            fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon ** 2) + 4 / 29) * linear_mask + \
                            (xyz_normalized_pixels ** (1 / 3)) * exponential_mask
            # convert to lab
            fxfyfz_to_lab = tf.constant([
                #   l       a      b
                [0.0,    500.0,    0.0],  # fx
                [116.0, -500.0,  200.0],  # fy
                [0.0,      0.0, -200.0],  # fz
            ])
            lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
        return tf.reshape(lab_pixels, tf.shape(srgb))

def lab_to_rgb(lab):
    with tf.name_scope("lab_to_rgb"):
        lab_pixels = tf.reshape(lab, [-1, 3])
        lab_pixels = tf.cast(lab_pixels, tf.float32)
        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("cielab_to_xyz"):
            # convert to fxfyfz
            lab_to_fxfyfz = tf.constant([
                #      fx         fy         fz
                [1 / 116.0, 1 / 116.0, 1 / 116.0],  # l
                [1 / 500.0,       0.0,        0.0],  # a
                [0.0,             0.0, -1 / 200.0],  # b
            ])
            fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)
            # convert to xyz
            epsilon = 6 / 29
            linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
            exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
            xyz_pixels = (3 * epsilon ** 2 * (fxfyfz_pixels - 4 / 29)) * linear_mask + \
                         (fxfyfz_pixels ** 3) * exponential_mask
            # denormalize for D65 white point
            xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])
        with tf.name_scope("xyz_to_srgb"):
            xyz_to_rgb = tf.constant([
                #      r           g           b
                [3.2404542, -0.9692660,  0.0556434],  # x
                [-1.5371385, 1.8760108, -0.2040259],  # y
                [-0.4985314, 0.0415560,  1.0572252],  # z
            ])
            rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
            # avoid a slightly negative number messing up the conversion
            rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
            linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
            exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
            srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + \
                          ((rgb_pixels ** (1 / 2.4) * 1.055) - 0.055) * exponential_mask
        return tf.reshape(srgb_pixels, tf.shape(lab))
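These conversion helpers are easy to sanity-check with a round trip: converting an RGB image to LAB and back should approximately reproduce the input, up to clipping and floating-point error. A minimal check:

import numpy as np
import tensorflow as tf

# Round-trip sanity check for the color-space helpers above.
rgb = tf.convert_to_tensor(np.random.rand(1, 8, 8, 3).astype(np.float32))  # values in [0, 1]
lab = rgb_to_lab(rgb)
rgb_back = lab_to_rgb(lab)
print(float(tf.reduce_max(tf.abs(rgb - rgb_back))))  # should be near 0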

This code is fairly long; it implements helper functions such as the RGB/LAB conversions. Since its logic is not directly tied to the main algorithm, a rough understanding is enough; there is no need to study it closely. Next comes dataset loading. Training uses the Places365 dataset, a compressed copy of which is included in the book's companion directory. The following command unpacks it:

!tar -xvf '/content/drive/Shared drives/chenyi19820904.edu.us/place365_dataset/test-256.tar'  # unpack the dataset

After unpacking the dataset to disk with the command above, load the data with the following code:

import os
import glob
import numpy as np
import tensorflow as tf
from abc import abstractmethod
# scipy.misc.imread was removed in SciPy >= 1.2; imageio is a drop-in replacement
from imageio import imread

PLACES365_DATASET = 'places365'

class BaseDataset():
    # read images into memory one by one for network training
    def __init__(self, name, path, training=True, augment=True):
        self.name = name
        self.augment = augment and training
        self.training = training
        self.path = path
        self._data = []

    def __len__(self):
        return len(self.data)

    def __iter__(self):
        total = len(self)
        start = 0
        while start < total:
            item = self[start]
            start += 1
            yield item
        # per PEP 479, simply return to end a generator
        # (raising StopIteration is an error in Python 3.7+)
        return

    def __getitem__(self, index):
        val = self.data[index]
        try:
            img = imread(val) if isinstance(val, str) else val
            # skip grayscale images (all three channels identical)
            if np.sum(img[:, :, 0] - img[:, :, 1]) == 0 and \
               np.sum(img[:, :, 0] - img[:, :, 2]) == 0:
                return None
            # data augmentation: flip horizontally with probability 0.5
            if self.augment and np.random.binomial(1, 0.5) == 1:
                img = img[:, ::-1, :]
        except:
            img = None
        return img

    def generator(self, batch_size, recursive=False):
        start = 0
        total = len(self)
        while True:
            while start < total:
                end = np.min([start + batch_size, total])
                items = []
                for ix in range(start, end):
                    item = self[ix]
                    if item is not None:
                        items.append(item)
                start = end
                yield items
            if recursive:
                start = 0
            else:
                return  # end the generator instead of raising StopIteration

    @property
    def data(self):
        if len(self._data) == 0:
            self._data = self.load()
            np.random.shuffle(self._data)
        return self._data

    @abstractmethod
    def load(self):
        return []

class Places365Dataset(BaseDataset):
    def __init__(self, path, training=True, augment=True):
        super(Places365Dataset, self).__init__(PLACES365_DATASET, path, training, augment)

    def load(self):
        # collect the paths of all jpg images in the dataset directory
        data = glob.glob(self.path + '/*.jpg', recursive=True)
        return data
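With the class defined, pulling a training batch is straightforward. The sketch below assumes the dataset was unpacked to /content/test_256/ as in the tar command above:

# Pull one training batch from the unpacked dataset.
dataset = Places365Dataset(path='/content/test_256/', training=True, augment=True)
batch = next(dataset.generator(16))   # a list of up to 16 images
print(len(batch), batch[0].shape)     # e.g. 16 (256, 256, 3)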

With the data and networks in place, the following code wires everything together and drives the training process. Note that it reuses the Generator and Discriminator classes defined earlier:

from IPython import display

class ColorGAN:
    def __init__(self):
        self.generator = None
        self.discriminator = None
        self.global_step = tf.Variable(0, dtype=tf.float32, trainable=False)
        self.create_generator_discriminator()
        self.data_generator = self.create_dataset(True)  # load the training dataset
        self.dataset_val = self.create_dataset(False)
        self.sample_generator = self.dataset_val.generator(8, True)
        self.learning_rate = tf.compat.v1.train.exponential_decay(
            learning_rate=3e-4,
            global_step=self.global_step,
            decay_steps=1e-5,
            decay_rate=0.1
        )  # the generator's learning rate should vary as training proceeds
        self.generator_optimizer = tf.optimizers.Adam(self.learning_rate, beta_1=0)
        self.discriminator_optimizer = tf.optimizers.Adam(3e-5, beta_1=0)
        self.batch_size = 16
        self.epochs = 5
        self.epoch = 0
        self.step = 0
        self.run_folder = "/content/drive/My Drive/ColorGAN/models/"
        # self.load_model()  # uncomment to load previously saved network weights

    def create_generator_discriminator(self):
        # build the generator and discriminator
        generator_encoder = [
            (64, 1), (64, 2), (128, 2), (256, 2),
            (512, 2), (512, 2), (512, 2), (512, 2)
        ]  # generator convolution-layer parameters: (filters, stride); kernel size is always 4
        generator_decoder = [
            (512, 2), (512, 2), (512, 2), (256, 2),
            (128, 2), (64, 2), (64, 2)
        ]  # generator transposed-convolution-layer parameters
        self.generator = Generator(generator_encoder, generator_decoder)
        self.generator.create_variables()
        discriminator_encoder = [
            (64, 2), (128, 2), (256, 2), (512, 1)
        ]  # discriminator convolution-layer parameters
        self.discriminator = Discriminator(discriminator_encoder)
        self.discriminator.create_variables()

    def train(self):
        for epoch in range(self.epochs):
            # feed training batches to both networks
            data_gen = self.data_generator.generator(self.batch_size)
            for img in data_gen:
                img = np.array(img)
                self.train_discriminator(img)
                # train the generator twice per discriminator update
                self.train_generator(img)
                self.train_generator(img)
                self.step += 1
                if self.step % 100 == 0:
                    # show intermediate colorization results
                    display.clear_output(wait=True)
                    self.sample()
                    self.save_model()

    def train_discriminator(self, img_color):
        img_gray = tf.image.rgb_to_grayscale(img_color)  # convert to grayscale
        img_gray = tf.cast(img_gray, tf.float32)
        lab_color_img = preprocess(img_color, colorspace_in=COLORSPACE_RGB,
                                   colorspace_out=COLORSPACE_LAB)  # convert to LAB
        gen_img = self.generator(img_gray)
        real_img = tf.concat([img_gray, lab_color_img], 3)
        fake_img = tf.concat([img_gray, gen_img], 3)
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            # push the output toward 1 for (grayscale, true color) pairs and
            # toward 0 for (grayscale, generated color) pairs
            tape.watch(self.discriminator.trainable_variables)
            discriminator_real = self.discriminator(real_img, training=True)
            discriminator_fake = self.discriminator(fake_img, training=True)
            loss_real = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=discriminator_real, labels=tf.ones_like(discriminator_real))
            loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=discriminator_fake, labels=tf.zeros_like(discriminator_fake))
            discriminator_loss = tf.reduce_mean(loss_real) + tf.reduce_mean(loss_fake)
        grads = tape.gradient(discriminator_loss, self.discriminator.trainable_variables)
        self.discriminator_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_variables))

    def train_generator(self, img_color):
        img_gray = tf.image.rgb_to_grayscale(img_color)
        img_gray = tf.cast(img_gray, tf.float32)
        lab_color_img = preprocess(img_color, colorspace_in=COLORSPACE_RGB,
                                   colorspace_out=COLORSPACE_LAB)  # convert to LAB
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            # push the discriminator's score on (grayscale, generated color)
            # pairs toward 1, i.e. pass the discriminator's inspection
            tape.watch(self.generator.trainable_variables)
            gen_img = self.generator(img_gray, training=True)
            fake_img = tf.concat([img_gray, gen_img], 3)
            discriminator_fake = self.discriminator(fake_img, training=True)
            loss_fake = tf.nn.sigmoid_cross_entropy_with_logits(
                logits=discriminator_fake, labels=tf.ones_like(discriminator_fake))
            generator_discriminator_loss = tf.reduce_mean(loss_fake)
            # L1 loss keeps the generated image close to the ground truth,
            # so generated objects keep the shapes of the input objects
            generator_content_loss = tf.reduce_mean(tf.abs(lab_color_img - gen_img)) * 100.0
            generator_loss = generator_discriminator_loss + generator_content_loss
        grads = tape.gradient(generator_loss, self.generator.trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(grads, self.generator.trainable_variables))

    def sample(self):
        # visually check the generator's colorization quality
        input_imgs = next(self.sample_generator)
        gray_imgs = tf.image.rgb_to_grayscale(input_imgs)
        gray_imgs = tf.cast(gray_imgs, tf.float32)
        fake_imgs = self.generator(gray_imgs, training=True)
        fake_imgs = postprocess(tf.convert_to_tensor(fake_imgs),
                                colorspace_in=COLORSPACE_LAB,
                                colorspace_out=COLORSPACE_RGB)
        # paste the three images side by side
        img_show = stitch_images(gray_imgs, input_imgs, fake_imgs.numpy())
        imshow(np.array(img_show), "color_gan")

    def save_model(self):
        # save the current network weights
        self.discriminator.save_weights(self.run_folder + "discriminator.h5")
        self.generator.save_weights(self.run_folder + "generator.h5")

    def load_model(self):
        # load previously saved network weights
        self.discriminator.load_weights(self.run_folder + "discriminator.h5")
        self.generator.load_weights(self.run_folder + "generator.h5")

    def create_dataset(self, training):
        # create the training / validation dataset
        return Places365Dataset(
            path='/content/test_256/',
            training=training,
            augment=True)

gan = ColorGAN()
gan.train()  # start the training process

Running this code starts the network training loop. Training takes a long time, so readers can instead load the author's pretrained network from the book's companion directory to inspect the results directly. After extended training, the author noticed an interesting phenomenon: the recolored image sometimes looks more aesthetically pleasing than the original color photograph. Figure 4 shows the trained network's colorization results.

Figure 4: Colorization results

Because of printing limitations, the colorization in Figure 4 may be hard to appreciate on the page. The figure places the grayscale image, the original color image, and the network's colorized output side by side for comparison; readers can obtain the full-color version from the book's companion directory to see the effect properly.
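For readers who want to skip training entirely, the sketch below shows one way to load previously saved weights and colorize a single validation image. It assumes the dataset path and the weight files written by save_model are available; note that the ColorGAN constructor expects the dataset to be present:

# Sketch: colorize one image with saved weights, without training.
gan = ColorGAN()
gan.load_model()  # load generator.h5 / discriminator.h5 from run_folder

imgs = next(gan.sample_generator)   # a batch of validation images
gray = tf.cast(tf.image.rgb_to_grayscale(np.array(imgs)), tf.float32)
fake_lab = gan.generator(gray)      # colorize in LAB space
fake_rgb = postprocess(fake_lab, colorspace_in=COLORSPACE_LAB,
                       colorspace_out=COLORSPACE_RGB)
imshow(fake_rgb[0].numpy(), 'colorized sample')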
