變分自編碼器VAE目前存在哪些問題?
代碼詳解:一文讀懂自動編碼器的前世今生,希望對你有幫助~
全文共5718字,預計學習時長20分鐘或更長變分自動編碼器(VAE)可以說是最實用的自動編碼器,但是在討論VAE之前,還必須了解一下用于數據壓縮或去噪的傳統自動編碼器。
變分自動編碼器的厲害之處假設你正在開發一款開放性世界端游,且游戲里的景觀設定相當復雜。
你聘用了一個圖形設計團隊來制作一些植物和樹木以裝飾游戲世界,但是將這些裝飾植物放進游戲中之后,你發現它們看起來很不自然,因為同種植物的外觀看起來一模一樣,這時你該怎么辦呢?
首先,你可能會建議使用一些參數化來嘗試隨機地改變圖像,但是多少改變才足夠呢?又需要多大的改變呢?還有一個重要的問題:實現這種改變的計算強度如何?
這是使用變分自動編碼器的理想情況。我們可以訓練一個神經網絡,使其學習植物的潛在特征,每當我們將一個植物放入游戲世界中,就可以從“已學習”的特征中隨機抽取一個樣本,生成獨特的植物。事實上,很多開放性世界游戲正在通過這種方法構建他們的游戲世界設定。
再看一個更圖形化的例子。假設我們是一個建筑師,想要為任意形狀的建筑生成平面圖??梢宰屢粋€自動編碼器網絡基于任意建筑形狀來學習數據生成分布,它將從數據生成分布中提取樣本來生成一個平面圖。詳見下方的動畫。
對于設計師來說,這些技術的潛力無疑是最突出的。
再假設我們為一個時裝公司工作,需要設計一種新的服裝風格,可以基于“時尚”的服裝來訓練自動編碼器,使其學習時裝的數據生成分布。隨后,從這個低維潛在分布中提取樣本,并以此來創造新的風格。
在該節中我們將研究fashion MNIST數據集。
自動編碼器傳統自動編碼器
自動編碼器其實就是非常簡單的神經結構。它們大體上是一種壓縮形式,類似于使用MP3壓縮音頻文件或使用jpeg壓縮圖像文件。
自動編碼器與主成分分析(PCA)密切相關。事實上,如果自動編碼器使用的激活函數在每一層中都是線性的,那么瓶頸處存在的潛在變量(網絡中最小的層,即代碼)將直接對應(PCA/主成分分析)的主要組件。通常,自動編碼器中使用的激活函數是非線性的,典型的激活函數是ReLU(整流線性函數)和sigmoid/S函數。
網絡背后的數學原理理解起來相對容易。從本質上看,可以把網絡分成兩個部分:編碼器和解碼器。
編碼器函數用?表示,該函數將原始數據X映射到潛在空間F中(潛在空間F位于瓶頸處)。解碼器函數用ψ表示,該函數將瓶頸處的潛在空間F映射到輸出函數。此處的輸出函數與輸入函數相同。因此,我們基本上是在一些概括的非線性壓縮之后重建原始圖像。
編碼網絡可以用激活函數傳遞的標準神經網絡函數表示,其中z是潛在維度。
相似地,解碼網絡可以用相同的方式表示,但需要使用不同的權重、偏差和潛在的激活函數。
隨后就可以利用這些網絡函數來編寫損失函數,我們會利用這個損失函數通過標準的反向傳播程序來訓練神經網絡。
由于輸入和輸出的是相同的圖像,神經網絡的訓練過程并不是監督學習或無監督學習,我們通常將這個過程稱為自我監督學習。自動編碼器的目的是選擇編碼器和解碼器函數,這樣就可以用最少的信息來編碼圖像,使其可以在另一側重新生成。
如果在瓶頸層中使用的節點太少,重新創建圖像的能力將受到限制,導致重新生成的圖像模糊或者和原圖像差別很大。如果使用的節點太多,那么就沒必要壓縮了。
壓縮背后的理論其實很簡單,例如,每當你在Netflix下載某些內容時,發送給你的數據都會被壓縮。一旦這個內容傳輸到電腦上就會通解壓算法在電腦屏幕顯示出來。這類似于zip文件的運行方式,只是這里說的壓縮是在后臺通過流處理算法完成的。
去噪自動編碼器
有幾種其它類型的自動編碼器。其中最常用的是去噪自動編碼器,本教程稍后會和Keras一起進行分析。這些自動編碼器在訓練前給數據添加一些白噪聲,但在訓練時會將誤差與原始圖像進行比較。這就使得網絡不會過度擬合圖像中出現的任意噪聲。稍后,將使用它來清除文檔掃描圖像中的折痕和暗黑區域。
稀疏自動編碼器
與其字義相反的是,稀疏自動編碼器具有比輸入或輸出維度更大的潛在維度。然而,每次網絡運行時,只有很小一部分神經元會觸發,這意味著網絡本質上是“稀疏”的。稀疏自動編碼器也是通過一種規則化的形式來減少網絡過度擬合的傾向,這一點與去噪自動編碼器相似。
收縮自動編碼器
收縮編碼器與前兩個自動編碼器的運行過程基本相同,但是在收縮自動編碼器中,我們不改變結構,只是在丟失函數中添加一個正則化器。這可以被看作是嶺回歸的一種神經形式。
現在了解了自動編碼器是如何運行的,接下來看看自動編碼器的弱項。一些最顯著的挑戰包括:
· 潛在空間中的間隙
· 潛在空間中的可分性
· 離散潛在空間
這些問題都在以下圖中體現。
MNIST數據集的潛在空間表示
這張圖顯示了潛在空間中不同標記數字的位置??梢钥吹綕撛诳臻g中存在間隙,我們不知道字符在這些空間中是長什么樣的。這相當于在監督學習中缺乏數據,因為網絡并沒有針對這些潛在空間的情況進行過訓練。另一個問題就是空間的可分性,上圖中有幾個數字被很好地分離,但也有一些區域被標簽字符是隨機分布的,這讓我們很難區分字符的獨特特征(在這個圖中就是數字0-9)。還有一個問題是無法研究連續的潛在空間。例如,我們沒有針對任意輸入而訓練的統計模型(即使我們填補了潛在空間中的所有間隙也無法做到)。
這些傳統自動編碼器的問題意味著我們還要做出更多努力來學習數據生成分布并生成新的數據與圖像。
現在已經了解了傳統自動編碼器是如何運行的,接下來討論變分自動編碼器。變分自動編碼器采用了一種從貝葉斯統計中提取的變分推理形式,因此會比前幾種自動編碼器稍微復雜一些。我們會在下一節中更深入地討論變分自動編碼器。
變分自動編碼器
變分自動編碼器延續了傳統自動編碼器的結構,并利用這一結構來學習數據生成分布,這讓我們可以從潛在空間中隨機抽取樣本。然后,可以使用解碼器網絡對這些隨機樣本進行解碼,以生成獨特的圖像,這些圖像與網絡所訓練的圖像具有相似的特征。
對于熟悉貝葉斯統計的人來說,編碼器正在學習后驗分布的近似值。這種分布通常很難分析,因為它沒有封閉式的解。這意味著我們要么執行計算上復雜的采樣程序,如馬爾可夫鏈蒙特卡羅(MCMC)算法,要么采用變分方法。正如你可能猜測的那樣,變分自動編碼器使用變分推理來生成其后驗分布的近似值。
我們將會用適量的細節來討論這一過程,但是如果你想了解更深入的分析,建議你閱覽一下Jaan Altosaar撰寫的博客。變分推理是研究生機器學習課程或統計學課程的一個主題,但是了解其基本概念并不需要擁有一個統計學學位。
若對背后的數學理論不感興趣,也可以選擇跳過這篇變分自動編碼器(VAE)編碼教程。
首先需要理解的是后驗分布以及它無法被計算的原因。先看看下面的方程式:貝葉斯定理。這里的前提是要知道如何從潛變量“z”生成數據“x”。這意味著要搞清p(z|x)。然而,該分布值是未知的,不過這并不重要,因為貝葉斯定理可以重新表達這個概率。但是這還沒有解決所有的問題,因為分母(證據)通常很難解。但也不是就此束手無辭了,還有一個挺有意思的辦法可以近似這個后驗分布值。那就是將這個推理問題轉化為一個優化問題。
要近似后驗分布值,就必須找出一個辦法來評估提議分布與真實后驗分布相比是否更好。而要這么做,就需要貝葉斯統計員的最佳伙伴:KL散度。KL散度是兩個概率分布相似度的度量。如果它們相等,那散度為零;而如果散度是正值,就代表這兩個分布不相等。KL散度的值為非負數,但實際上它不是一個距離,因為該函數不具有對稱性??梢圆捎孟旅娴姆绞绞褂肒L散度:
這個方程式看起來可能有點復雜,但是概念相對簡單。那就是先猜測可能生成數據的方式,并提出一系列潛在分布Q,然后再找出最佳分布q*,從將提議分布和真實分布的距離最小化,然后因其難解性將其近似。但這個公式還是有一個問題,那就是p(z|x)的未知值,所以也無法計算KL散度。那么,應該怎么解決這個問題呢?
這里就需要一些內行知識了??梢韵冗M行一些計算上的修改并針對證據下界(ELBO)和p(x)重寫KL散度:
有趣的是ELBO是這個方程中唯一取決于所選分布的變量。而后者由于不取決于q,則不受所選分布的影響。因此,可以在上述方程中通過將ELBO(負值)最大化來使KL散度最小化。這里的重點是ELBO可以被計算,也就是說現在可以進行一個優化流程。
所以現在要做的就是給Q做一個好的選擇,再微分ELBO,將其設為零,然后就大功告成了。可是開始的時候就會面臨一些障礙,即必須選擇最好的分布系列。
一般來說,為了簡化定義q的過程,會進行平均場變分推理。每個變分參數實質上是相互獨立的。因此,每個數據點都有一個單獨的q,可被相稱以得到一個聯合概率,從而獲得一個“平均場”q。
實際上,可以選用任意多的場或者集群。比如在MINIST數據集中,可以選擇10個集群,因為可能有10個數字存在。
要做的第二件事通常被稱為再參數化技巧,通過把隨機變量帶離導數完成,因為從隨機變量求導數的話會由于它的內在隨機性而產生較大的誤差。
再參數化技巧較為深奧,但簡單來說就是可以將一個正態分布寫成均值加標準差,再乘以誤差。這樣在微分時,我們不是從隨機變量本身求導數,而是從它的參數求得。
這個程序沒有一個通用的閉型解,所以近似后驗分布的能力仍然受到一定限制。然而,指數分布族確實有一個閉型解。這意味著標準分布,如正態分布、二項分布、泊松分布、貝塔分布等。所以,就算真正的后驗分布值無法被查出,依然可以利用指數分布族得出最接近的近似值。
變分推理的奧秘在于選擇分布區Q,使其足夠大以求得后驗分布的近似值,但又不需要很長時間來計算。
既然已經大致了解如何訓練網絡學習數據的潛在分布,那么現在可以探討如何使用這個分布生成數據。
數據生成過程
觀察下圖,可以看出對數據生成過程的近似認為應生成數字‘2’,所以它從潛在變量質心生成數值2。但是也許不希望每次都生成一摸一樣的數字‘2’,就好像上述端游例子所提的植物,所以我們根據一個隨機數和“已學”的數值‘2’分布范圍,在潛在空間給這一過程添加了一些隨機噪聲。該過程通過解碼器網絡后,我們得到了一個和原型看起來不一樣的‘2’。
這是一個非常簡化的例子,抽象描述了實際自動編碼器網絡的體系結構。下圖表示了一個真實變分自動編碼器在其編碼器和解碼器網絡使用卷積層的結構體系。從這里可以觀察到,我們正在分別學習潛在空間中生成數據分布的中心和范圍,然后從這些分布“抽樣”生成本質上“虛假”的數據。
該學習過程的固有性代表所有看起來很相似的參數(刺激相同的網絡神經元放電)都聚集到潛在空間中,而不是隨意的分散。如下圖所示,可以看到數值2都聚集在一起,而數值3都逐漸地被推開。這一過程很有幫助,因為這代表網絡并不會在潛在空間隨意擺放字符,從而使數值之間的轉換更有真實性。
整個網絡體系結構的概述如下圖所示。希望讀者看到這里,可以比較清晰地理解整個過程。我們使用一組圖像訓練自動編碼器,讓它學習潛在空間里均值和標準值的差,從而形成我們的數據生成分布。接下來,當我們要生成一個類似的圖像,就從潛在空間的一個質心取樣,利用標準差和一些隨機誤差對它進行輕微的改變,然后使其通過解碼器網絡。從這個例子可以明顯看出,最終的輸出看起來與輸入圖像相似,但卻是不一樣的。
變分自動編碼器編碼指南
本節將討論一個簡單的去噪自動編碼器,用于去除文檔掃描圖像上的折痕和污痕,以及去除Fashion MNIST數據集中的噪聲。然后,在MNIST數據集訓練網絡后,就使用變分自動編碼器生成新的服裝。
去噪自編碼器
Fashion MNIST
在第一個練習中,在Fashion MNIST數據集添加一些隨機噪聲(椒鹽噪聲),然后使用去噪自編碼器嘗試移除噪聲。首先進行預處理:下載數據,調整數據大小,然后添加噪聲。
## Download the data(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()## normalize and reshapex_train = x_train/255.x_test = x_test/255.x_train = x_train.reshape(-1, 28, 28, 1)x_test = x_test.reshape(-1, 28, 28, 1)# Lets add sample noise - Salt and Peppernoise = augmenters.SaltAndPepper(0.1)seq_object = augmenters.Sequential([noise])train_x_n = seq_object.augment_images(x_train * 255) / 255val_x_n = seq_object.augment_images(x_test * 255) / 255接著,給自編碼器網絡創建結構。這包括多層卷積神經網絡、編碼器網絡的最大池化層和解碼器網絡上的升級層。
# input layerinput_layer =Input(shape=(28, 28, 1)) # encodingarchitectureencoded_layer1= Conv2D(64, (3, 3), activation='relu', padding='same')(input_layer)encoded_layer1= MaxPool2D( (2, 2), padding='same')(encoded_layer1)encoded_layer2= Conv2D(32, (3, 3), activation='relu', padding='same')(encoded_layer1)encoded_layer2= MaxPool2D( (2, 2), padding='same')(encoded_layer2)encoded_layer3= Conv2D(16, (3, 3), activation='relu', padding='same')(encoded_layer2)latent_view = MaxPool2D( (2, 2),padding='same')(encoded_layer3) # decodingarchitecturedecoded_layer1= Conv2D(16, (3, 3), activation='relu', padding='same')(latent_view)decoded_layer1= UpSampling2D((2, 2))(decoded_layer1)decoded_layer2= Conv2D(32, (3, 3), activation='relu', padding='same')(decoded_layer1)decoded_layer2= UpSampling2D((2, 2))(decoded_layer2)decoded_layer3= Conv2D(64, (3, 3), activation='relu')(decoded_layer2)decoded_layer3= UpSampling2D((2, 2))(decoded_layer3)output_layer = Conv2D(1, (3, 3), padding='same',activation='sigmoid')(decoded_layer3) # compile themodelmodel =Model(input_layer, output_layer)model.compile(optimizer='adam',loss='mse') # run themodelearly_stopping= EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=5,mode='auto')history =model.fit(train_x_n, x_train, epochs=20, batch_size=2048,validation_data=(val_x_n, x_test), callbacks=[early_stopping])所輸入的圖像,添加噪聲的圖像,和輸出圖像。
從時尚MNIST輸入的圖像。
添加椒鹽噪聲的輸入圖像。
從去噪網絡輸出的圖像。
從這里可以看到,我們成功從噪聲圖像去除相當的噪聲,但同時也失去了一定量的服裝細節的分辨率。這是使用穩健網絡所需付出的代價之一。可以對該網絡進行調優,使最終的輸出更能代表所輸入的圖像。
文本清理
去噪自編碼器的第二個例子包括清理掃描圖像的折痕和暗黑區域。這是最終獲得的輸入和輸出圖像。
輸入的有噪聲文本數據圖像。
經清理的文本圖像。
為此進行的數據預處理稍微復雜一些,因此就不在這里進行介紹,預處理過程和相關數據可在GitHub庫里獲取。網絡結構如下:
input_layer= Input(shape=(258, 540, 1)) #encoderencoder= Conv2D(64, (3, 3), activation='relu', padding='same')(input_layer)encoder= MaxPooling2D((2, 2), padding='same')(encoder) #decoderdecoder= Conv2D(64, (3, 3), activation='relu', padding='same')(encoder)decoder= UpSampling2D((2, 2))(decoder)output_layer= Conv2D(1, (3, 3), activation='sigmoid', padding='same')(decoder) ae =Model(input_layer, output_layer) ae.compile(loss='mse',optimizer=Adam(lr=0.001)) batch_size= 16epochs= 200 early_stopping= EarlyStopping(monitor='val_loss',min_delta=0,patience=5,verbose=1,mode='auto')history= ae.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,validation_data=(x_val, y_val), callbacks=[early_stopping])變分自編碼器
最后的壓軸戲,是嘗試從FashionMNIST數據集現有的服裝中生成新圖像。
其中的神經結構較為復雜,包含了一個稱‘Lambda’層的采樣層。
batch_size = 16latent_dim = 2 # Number of latent dimension parameters# ENCODER ARCHITECTURE: Input -> Conv2D*4 -> Flatten -> Denseinput_img = Input(shape=(28, 28, 1))x = Conv2D(32, 3, padding='same', activation='relu')(input_img)x = Conv2D(64, 3, padding='same', activation='relu', strides=(2, 2))(x)x = Conv2D(64, 3, padding='same', activation='relu')(x)x = Conv2D(64, 3, padding='same', activation='relu')(x)# need to know the shape of the network here for the decodershape_before_flattening = K.int_shape(x)x = Flatten()(x)x = Dense(32, activation='relu')(x)# Two outputs, latent mean and (log)variancez_mu = Dense(latent_dim)(x)z_log_sigma = Dense(latent_dim)(x)## SAMPLING FUNCTIONdef sampling(args): z_mu, z_log_sigma = args epsilon = K.random_normal(shape=(K.shape(z_mu)[0], latent_dim), mean=0., stddev=1.) return z_mu + K.exp(z_log_sigma) * epsilon# sample vector from the latent distributionz = Lambda(sampling)([z_mu, z_log_sigma])## DECODER ARCHITECTURE# decoder takes the latent distribution sample as inputdecoder_input = Input(K.int_shape(z)[1:])# Expand to 784 total pixelsx = Dense(np.prod(shape_before_flattening[1:]), activation='relu')(decoder_input)# reshapex = Reshape(shape_before_flattening[1:])(x)# use Conv2DTranspose to reverse the conv layers from the encoderx = Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2, 2))(x)x = Conv2D(1, 3, padding='same', activation='sigmoid')(x)# decoder model statementdecoder = Model(decoder_input, x)# apply the decoder to the sample from the latent distributionz_decoded = decoder(z)這就是體系結構,但還是需要插入損失函數再合并KL散度。# construct a custom layer to calculate the lossclass CustomVariationalLayer(Layer): def vae_loss(self, x, z_decoded): x = K.flatten(x) z_decoded = K.flatten(z_decoded) # Reconstruction loss xent_loss = binary_crossentropy(x, z_decoded) # KL divergence kl_loss = -5e-4 * K.mean(1 + z_log_sigma - K.square(z_mu) - K.exp(z_log_sigma), axis=-1) return K.mean(xent_loss + kl_loss) # adds the custom loss to the class def call(self, inputs): x = inputs[0] z_decoded = inputs[1] loss = self.vae_loss(x, z_decoded) self.add_loss(loss, inputs=inputs) return x# apply the custom loss to the input images and the decoded latent distribution sampley = CustomVariationalLayer()([input_img, z_decoded])# VAE model statementvae = Model(input_img, y)vae.compile(optimizer='rmsprop', loss=None)vae.fit(x=train_x, y=None, shuffle=True, epochs=20, batch_size=batch_size, validation_data=(val_x, None))現在,可以查看重構的樣本,看看網絡能夠學習到什么。
從這里可以清楚看到鞋子、手袋和服裝之間的過渡。在此并沒有標出所有使畫面更清晰的潛在空間。也可以觀察到Fashion MNIST數據集現有的10件服裝的潛在空間和顏色代碼。
可看出這些服飾分成了不同的集群。
留言 點贊 關注
我們一起分享AI學習與發展的干貨
歡迎關注全平臺AI垂類自媒體 “讀芯術”