量子生成对抗网络

news/2024/11/26 15:49:14/

        生成对抗网络是由两部分神经网络组成,分别为生成器和判别器。量子生成对抗网络的基本原理与经典的GAN基本相同。

        生成对抗网络由GoodFellow等人提出,主要功能是生成伪造的样本。生成器G的输入为一个随机向量z,输出为一个生成样本G(z)。判别器D的输入为一个真实的或生成的样本x,输出为对该样本的判断。当输入真实样本x时,判别器的输出为D(x),接近1。当输入生成样本G(z)时,判别器的输出为D(G(z)),接近0。

GAN的优化目标函数可以表示为:

 

参数化量子线路 

量子生成对抗网络所用的参数化量子线路主要由三部分组成:旋转层、叠加层和纠缠层

旋转层由旋转门Ry(θ) 组成

叠加层使用的是双量子比特门

 其中 旋转层由旋转门R_{y}(\Theta )组成。叠加层使用的是双量子比特门R_{yy}(\Theta )

 

        相比起神经网络,多了一个叠加层,主要是因为量子资源和设备的限制,退相干会导致系统内部的相互作用关系发生变化,影响了系统的纠缠程度,进而影响QGAN中生成器和判别器的功能,而叠加层可以减少这种影响。

        量子判别器和量子生成器都采用前述的参数化量子线路

        QGAN通过经典计算机和量子计算机之间的迭代切换,找到参数化量子线路的最优参数

   QGAN将保真度作为目标函数 ,因此用交换测试得到保真度
   量子判别器之后量子态为|\delta _{1}\rangle,量子生成器之后量子态为|\gamma \rangle,保真度|\langle \gamma|\delta _{1} \rangle|^{2}要尽量接近0
   量子判别器之后量子态为|\delta _{2}\rangle  ,真实样本为|\xi \rangle,保真度|\langle \gamma|\delta _{2} \rangle|^{2}要尽量接近1
   因此 训练量子判别器时要最小化  -log(1-|\langle \gamma|\delta _{1} \rangle|^{2}) -log(|\langle \gamma|\delta _{2} \rangle|^{2})
训练量子生成器时要最小化 -log(|\langle \gamma|\delta _{2} \rangle|^{2})

 

 生成一个带有数字39的图片真实图片由手写数字39组成,特征为2

 

代码

#qiskit版本>1.0.0
import random
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets
from qiskit import QuantumRegister, ClassicalRegister,transpile
from qiskit import QuantumCircuit
from qiskit_aer import Aer
from math import pi
from qiskit import *  
import tensorflow as tf
from sklearn.decomposition import PCA
import time
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patheffects as PathEffects#数据准备
test_images,test_labels = tf.keras.datasets.mnist.load_data()
train_images = test_images[0].reshape(60000,784)
train_labels = test_images[1]
labels = test_images[1]
train_images = train_images/255
k=2
pca = PCA(n_components=k)
pca.fit(train_images)
pca_data = pca.transform(train_images)[:10000]
train_labels = train_labels[:10000]
t_pca_data = pca_data.copy()
pca_descaler = [[] for _ in range(k)]
for i in range(k):if pca_data[:,i].min() < 0:pca_descaler[i].append(pca_data[:,i].min())pca_data[:,i] += np.abs(pca_data[:,i].min())else:pca_descaler[i].append(pca_data[:,i].min())pca_data[:,i] -= pca_data[:,i].min()pca_descaler[i].append(pca_data[:,i].max())pca_data[:,i] /= pca_data[:,i].max()
pca_data_rot= 2*np.arcsin(np.sqrt(pca_data))
valid_labels = None
valid_labels = train_labels==9
valid_labels = train_labels == 3 
pca_data_rot = pca_data_rot[valid_labels]
pca_data = pca_data[valid_labels]#利用特征重构图像
def descale_points(d_point,scales=pca_descaler,tfrm=pca):for col in range(d_point.shape[1]):d_point[:,col] *= scales[col][1]d_point[:,col] += scales[col][0]reconstruction = tfrm.inverse_transform(d_point)return reconstruction
learning_rate=0.01#初始化判别器和生成器参数
thetaD=[np.random.rand()*np.pi for i in range(4)]
thetaG=[np.random.rand()*np.pi for i in range(4)]
#定义数组分别保存更新后判别器和生成器的参数
thetaD_new=[0,0,0,0]
thetaG_new=[0,0,0,0]
#保存被加载的真数据
data=[0,0]#定义训练判别器的量子线路
circuit=QuantumCircuit(5,5)
for i in range(1,3):circuit.ry(thetaD[i-1],i)
circuit.ryy(thetaD[2],1,2)
circuit.cry(thetaD[3],1,2)
for i in range(1,3):circuit.ry(thetaG[i-1],i+2)
circuit.ryy(thetaG[2],3,4)
circuit.cry(thetaG[3],3,4)
circuit.barrier(0,1,2,3,4)
circuit.h(0)
for i in range(1,3):circuit.cswap(0,i,i+2)
circuit.h(0)
circuit.measure(0,0)
circuit.draw(output='mpl', plot_barriers=False)#定义训练生成器的量子线路
circuit1=QuantumCircuit(5,5)
for i in range(1,3):circuit1.ry(thetaD[i-1],i)
circuit1.ryy(thetaD[2],1,2)
circuit1.cry(thetaD[3],1,2)
circuit1.ry(data[0],3)
circuit1.ry(data[1],4)
circuit1.barrier(0,1,2,3,4)
circuit1.h(0)
for i in range(1,3):circuit1.cswap(0,i,i+2)
circuit1.h(0)
circuit1.measure(0,0)
circuit1.draw(output='mpl', plot_barriers=False)#定义生成器线路,便于得到训练后生成的数据
circuit2=QuantumCircuit(2,2)
circuit2.ry(thetaG[0],0)
circuit2.ry(thetaG[1],1)
circuit2.ryy(thetaG[2],0,1)
circuit2.cry(thetaG[3],0,1)
circuit2.measure(0,0)
circuit2.measure(1,1)
circuit2.draw(output='mpl')
value1=0
value2=0
f_diff=0
b_diff=0
backend=Aer.get_backend('qasm_simulator')
for epoch in range(0,50):print(f'第--{epoch}--次训练')par_shift = 0.5*np.pi# 通过circuit线路,使用假数据训练判别器for i in range(1):for j in range(0,4):thetaD[j]+=par_shift# Set up the backendbackend = Aer.get_backend('qasm_simulator')transpiled_qc = transpile(circuit, backend)# Run the transpiled circuitjob_sim1 = backend.run(transpiled_qc, shots=2000)sim_result1 = job_sim1.result()measurement_result1 = sim_result1.get_counts(circuit)value1=2*(measurement_result1['00000']/2000-0.5)if value1 <= 0.005:value1 = 0.005f_diff=-np.log(1-value1)thetaD[j]-=2*par_shiftbackend = Aer.get_backend('qasm_simulator')transpiled_qc = transpile(circuit, backend)# Run the transpiled circuitjob_sim1 = backend.run(transpiled_qc, shots=2000)sim_result1 = job_sim1.result()measurement_result1 = sim_result1.get_counts(circuit)value2=2*(measurement_result1['00000']/2000-0.5)if value2 <= 0.005:value2 = 0.005b_diff=-np.log(1-value2)thetaD[j]+=par_shiftdf=0.5*(f_diff-b_diff)if abs(df)>1:df = df/abs(df)thetaD_new[j]=thetaD[j]-learning_rate*df/10thetaD=thetaD_new# 使用circuit1线路训练判别器for index,point in enumerate(pca_data_rot):data[0] = point[0]data[1] = point[1]for j in range(0,4):thetaD[j]+=par_shiftbackend = Aer.get_backend('qasm_simulator')transpiled_qc1 = transpile(circuit1, backend)# Run the transpiled circuitjob_sim = backend.run(transpiled_qc1, shots=1000)sim_result = job_sim.result()measurement_result = sim_result.get_counts(circuit1)value1=2*(measurement_result['00000']/1000-0.5)if value1 <= 0.005:value1 = 0.005f_diff=-np.log(value1)thetaD[j]-=2*par_shiftbackend = Aer.get_backend('qasm_simulator')transpiled_qc1 = transpile(circuit1, backend)# Run the transpiled circuitjob_sim = backend.run(transpiled_qc1, shots=1000)sim_result = job_sim.result()measurement_result = sim_result.get_counts(circuit1)value2=2*(measurement_result['00000']/1000-0.5)if value2 <= 0.005:value2 = 0.005thetaD[j]+=par_shiftb_diff=-np.log(value2)df=0.5*(f_diff-b_diff)thetaD_new[j]=thetaD[j]-learning_rate*df/10thetaD=thetaD_new# 通过circuit线路,利用判别器训练生成器for i in range(len(pca_data_rot)//10):for j in range(0,4):thetaG[j]+=par_shiftbackend = Aer.get_backend('qasm_simulator')transpiled_qc = transpile(circuit, backend)# Run the transpiled circuitjob_sim = backend.run(transpiled_qc, shots=1000)sim_result = job_sim.result()measurement_result = sim_result.get_counts(circuit)value1=2*(measurement_result['00000']/1000-0.5)if value1 <= 0.005:value1 = 0.005f_diff=-np.log(value1)thetaG[j]-=2*par_shiftbackend = Aer.get_backend('qasm_simulator')transpiled_qc = transpile(circuit, backend)# Run the transpiled circuitjob_sim = backend.run(transpiled_qc, shots=1000)sim_result = job_sim.result()measurement_result = sim_result.get_counts(circuit)value2=2*(measurement_result['00000']/1000-0.5)if value2 <= 0.005:value2 = 0.005thetaG[j]+=par_shiftb_diff=-np.log(value2)df=0.5*(f_diff-b_diff)thetaG_new[j]=thetaG[j]-learning_rate*df*5thetaG=thetaG_newdata = []n_results = 2#使用训练好的生成器生成数据并使用逆PCA生成图像for _ in range(16):backend = Aer.get_backend('qasm_simulator')transpiled_qc2 = transpile(circuit2, backend)# Run the transpiled circuitjob = backend.run(transpiled_qc2, shots=1000)results = job.result()measurement_result2 = results.get_counts(circuit2)bins = [[0,0] for _ in range(n_results)]for key,value in measurement_result2.items():for i in range(n_results):if key[-i-1]== '1':bins[i][0] += valuebins[i][1] += valuefor i,pair in enumerate(bins):bins[i]= pair[0]/pair[1]data.append(bins)data = np.array(data)new_info = descale_points(data[:16])new_info = new_info.reshape(new_info.shape[0],28,28)print(f"Epoch {epoch} Generated Images")for i in range(new_info.shape[0]):plt.subplot(4, 4, i+1)plt.imshow(new_info[i, :, :], cmap='gray')plt.axis('off')plt.show()

 


http://www.ppmy.cn/news/1550095.html

相关文章

Git Github Gitlab与Gitee的关系

Git是代码版本管理工具 -------项目通过Git可以切换到任意代码版本 Github和Gitee是基于Git技术构建的远程仓库网站 -------可以将你的代码仓库提交上去保存 GitHub与Gitee的区别 -------前者是国外建立,资源更丰富,后者是国内建立,免费功能更多 Gitlab和Github功能类似 …

Python人工智能项目报告

一、实践概述 1、实践计划和目的 在现代社会&#xff0c;计算机技术已成为支撑社会发展的核心力量&#xff0c;渗透到生活的各个领域&#xff0c;应关注人类福祉&#xff0c;确保自己的工作成果能够造福社会&#xff0c;同时维护安全、健康的自然环境&#xff0c;设计出具有包…

HarmonyOS NEXT应用元服务开发Intents Kit(意图框架服务)习惯推荐方案开发者测试

意图框架向开发者提供真机测试能力&#xff0c;即开发者可连接设备进行调测。开发者完成代码开发之后&#xff0c;功能正式上架应用市场前&#xff0c;可以在HarmonyOS NEXT设备上面进行自验证&#xff0c;打磨体验。真机测试分为三个步骤&#xff1a;基础信息提供&#xff0c;…

学习threejs,使用设置bumpMap凹凸贴图创建褶皱,实现贴图厚度效果

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言1.1 ☘️THREE.MeshPhongMaterial高…

MySQL之索引与事务

索引 索引的分类 从定义的分类来看&#xff0c;索引分为&#xff1a; 主键索引&#xff1a;必须唯一且不能有null值 唯一索引&#xff1a;必须唯一&#xff0c;但是允许有null值 普通索引&#xff1a;即对一个列添加索引&#xff0c;也称单列索引 联合索引&#xff1a;对多个…

企业办公自动化:Spring Boot OA管理系统开发与实践

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统&#xff0c;它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等&#xff0c;非常…

从零开始打造个人博客:我的网页设计之旅

✅作者简介&#xff1a;2022年博客新星 第八。热爱国学的Java后端开发者&#xff0c;修心和技术同步精进。 &#x1f34e;个人主页&#xff1a;Java Fans的博客 &#x1f34a;个人信条&#xff1a;不迁怒&#xff0c;不贰过。小知识&#xff0c;大智慧。 ✨特色专栏&#xff1a…

算法日记 33 day 动态规划(打家劫舍,股票买卖)

今天来看看动态规划的打家劫舍和买卖股票的问题。 上题目&#xff01;&#xff01;&#xff01;&#xff01; 题目&#xff1a;打家劫舍 198. 打家劫舍 - 力扣&#xff08;LeetCode&#xff09; 你是一个专业的小偷&#xff0c;计划偷窃沿街的房屋。每间房内都藏有一定的现金…