Milvus向量库和SeetaSDK工具类分享
- 1.Milvus向量库工具类
- 2.SeetaSDK工具类
1.Milvus向量库工具类
Milvus的Maven依赖:
<!-- Milvus Java SDK 2.1.0 (provides MilvusServiceClient, InsertParam, SearchParam, ...).
     log4j-slf4j-impl is excluded, presumably so that it does not add a second SLF4J binding
     next to the application's own logging backend - confirm against the project's logging setup. -->
<dependency>
    <groupId>io.milvus</groupId>
    <artifactId>milvus-sdk-java</artifactId>
    <version>2.1.0</version>
    <exclusions>
        <exclusion>
            <artifactId>log4j-slf4j-impl</artifactId>
            <groupId>org.apache.logging.log4j</groupId>
        </exclusion>
    </exclusions>
</dependency>
向量库的配置类:
/**
 * Connection settings for the Milvus vector database, bound from the
 * {@code milvus-config.*} application properties.
 *
 * <p>The fields are public because other components read them directly
 * (for example {@code milvusConfiguration.host}); Lombok's {@code @Data}
 * additionally generates getters and setters for them.
 */
@Data
@Component
@ConfigurationProperties(MilvusConfiguration.PREFIX)
public class MilvusConfiguration {

    /** Property prefix under which the settings below are bound. */
    public static final String PREFIX = "milvus-config";

    /** Host name or IP address of the Milvus server. */
    public String host;

    /** Port of the Milvus server, passed to {@code ConnectParam.withPort}. */
    public int port;

    /** Name of the collection that is loaded, inserted into and searched. */
    public String collectionName;
}
工具类主类:
@Slf4j
@Component
public class MilvusUtil {@Resourceprivate MilvusConfiguration milvusConfiguration;private MilvusServiceClient milvusServiceClient;@PostConstructprivate void connectToServer() {milvusServiceClient = new MilvusServiceClient(ConnectParam.newBuilder().withHost(milvusConfiguration.host).withPort(milvusConfiguration.port).build());// 加载数据LoadCollectionParam faceSearchNewLoad = LoadCollectionParam.newBuilder().withCollectionName(milvusConfiguration.collectionName).build();R<RpcStatus> rpcStatusR = milvusServiceClient.loadCollection(faceSearchNewLoad);log.info("Milvus LoadCollection [{}]", rpcStatusR.getStatus() == 0 ? "Successful!" : "Failed!");}
}
主类里的数据入库方法:
public int insertDataToMilvus(String id, String path, float[] feature) {List<InsertParam.Field> fields = new ArrayList<>();List<Float> featureList = new ArrayList<>(feature.length);for (float v : feature) {featureList.add(v);}fields.add(new InsertParam.Field("id", Collections.singletonList(id)));fields.add(new InsertParam.Field("image_path", Collections.singletonList(path)));fields.add(new InsertParam.Field("image_feature", Collections.singletonList(featureList)));InsertParam insertParam = InsertParam.newBuilder().withCollectionName(milvusConfiguration.collectionName)//.withPartitionName("novel").withFields(fields).build();R<MutationResult> insert = milvusServiceClient.insert(insertParam);return insert.getStatus();}
主类里的数据查询方法:
- 注意:这里的 topK 没有进行参数化,固定为 10。
public List<MilvusRes> searchImageByFeatureVector(float[] feature) {List<Float> featureList = new ArrayList<>(feature.length);for (float v : feature) {featureList.add(v);}List<String> queryOutputFields = Arrays.asList("image_path");SearchParam faceSearch = SearchParam.newBuilder().withCollectionName(milvusConfiguration.collectionName).withMetricType(MetricType.IP).withVectorFieldName("image_feature").withVectors(Collections.singletonList(featureList)).withOutFields(queryOutputFields).withRoundDecimal(3).withTopK(10).build();// 执行搜索long l = System.currentTimeMillis();R<SearchResults> respSearch = milvusServiceClient.search(faceSearch);log.info("MilvusServiceClient.search cost [{}]", System.currentTimeMillis() - l);// 解析结果数据SearchResultData results = respSearch.getData().getResults();int scoresCount = results.getScoresCount();SearchResultsWrapper wrapperSearch = new SearchResultsWrapper(results);List<MilvusRes> milvusResList = new ArrayList<>();for (int i = 0; i < scoresCount; i++) {float score = wrapperSearch.getIDScore(0).get(i).getScore();Object imagePath = wrapperSearch.getFieldData("image_path", 0).get(i);MilvusRes milvusRes = MilvusRes.builder().score(score).imagePath(imagePath.toString()).build();milvusResList.add(milvusRes);}return milvusResList;}
2.SeetaSDK工具类
SeetaSDK的Maven依赖:
<!-- SeetaFace SDK jar, built locally from the official source and referenced from ${project.basedir}/lib
     via system scope (it is not resolved from a Maven repository). -->
<dependency>
    <groupId>com.seeta</groupId>
    <artifactId>sdk</artifactId>
    <version>1.2.1</version>
    <scope>system</scope>
    <systemPath>${project.basedir}/lib/seeta-sdk-platform-1.2.1.jar</systemPath>
</dependency>
<!-- Note: spring-boot-maven-plugin leaves system-scoped dependencies out of the repackaged jar
     unless includeSystemScope is true. This element goes under <build><plugins>, not <dependencies>. -->
<plugin>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-maven-plugin</artifactId>
    <configuration>
        <includeSystemScope>true</includeSystemScope>
    </configuration>
</plugin>
jar 包是用从官网下载的源码自行打包得到的:
工具类主类:
@Slf4j
@Component
public class FaceUtil {static {// 加载本地方法LoadNativeCore.LOAD_NATIVE(SeetaDevice.SEETA_DEVICE_AUTO);}@Resourceprivate SeetaModelConfiguration seetaModelConfiguration;private FaceDetectorProxy faceDetectorProxy;private FaceLandmarkerProxy faceLandmarkerProxy;private FaceRecognizerProxy faceRecognizerProxy;private AgePredictorProxy agePredictorProxy;private GenderPredictorProxy genderPredictorProxy;private MaskDetectorProxy maskDetectorProxy;private EyeStateDetectorProxy eyeStateDetectorProxy;}
主类里的初始化方法:
/**
 * Creates the SeetaFace proxies from the model files under
 * {@code seetaModelConfiguration.basePath}.
 *
 * <p>Proxies are built in order. If one fails, the error is logged and the remaining
 * proxies (including the failing one) stay {@code null}; startup is not aborted.
 *
 * <p>NOTE(review): in SeetaFace6 the integer argument of {@code SeetaModelSetting} is
 * documented as the device (GPU) id rather than a model index. The incrementing values
 * 0..5 below look like a copy-paste, and the eye-state detector reuses 5. They are harmless
 * on CPU, but confirm against the SDK before running on GPU.
 */
@PostConstruct
private void inti() {
    String basePath = seetaModelConfiguration.basePath;
    try {
        // Face detector pool.
        SeetaConfSetting detectorPoolSetting = new SeetaConfSetting(new SeetaModelSetting(0,
                new String[]{basePath + seetaModelConfiguration.faceDetectorModelFileName},
                SeetaDevice.SEETA_DEVICE_AUTO));
        faceDetectorProxy = new FaceDetectorProxy(detectorPoolSetting);

        // Landmark locator (5 points by default; a 68-point model can be configured instead).
        SeetaConfSetting faceLandmarkerPoolSetting = new SeetaConfSetting(new SeetaModelSetting(1,
                new String[]{basePath + seetaModelConfiguration.faceLandmarkerModelFileName},
                SeetaDevice.SEETA_DEVICE_AUTO));
        faceLandmarkerProxy = new FaceLandmarkerProxy(faceLandmarkerPoolSetting);

        // Face feature extractor / comparator.
        SeetaConfSetting faceRecognizerPoolSetting = new SeetaConfSetting(new SeetaModelSetting(2,
                new String[]{basePath + seetaModelConfiguration.faceRecognizerModelFileName},
                SeetaDevice.SEETA_DEVICE_AUTO));
        faceRecognizerProxy = new FaceRecognizerProxy(faceRecognizerPoolSetting);

        // Age predictor.
        SeetaConfSetting agePredictorPoolSetting = new SeetaConfSetting(new SeetaModelSetting(3,
                new String[]{basePath + seetaModelConfiguration.agePredictorModelFileName},
                SeetaDevice.SEETA_DEVICE_AUTO));
        agePredictorProxy = new AgePredictorProxy(agePredictorPoolSetting);

        // Gender predictor.
        SeetaConfSetting genderPredictorPoolSetting = new SeetaConfSetting(new SeetaModelSetting(4,
                new String[]{basePath + seetaModelConfiguration.genderPredictorModelFileName},
                SeetaDevice.SEETA_DEVICE_AUTO));
        genderPredictorProxy = new GenderPredictorProxy(genderPredictorPoolSetting);

        // Mask detector.
        SeetaConfSetting maskDetectorPoolSetting = new SeetaConfSetting(new SeetaModelSetting(5,
                new String[]{basePath + seetaModelConfiguration.maskDetectorModelFileName},
                SeetaDevice.SEETA_DEVICE_AUTO));
        maskDetectorProxy = new MaskDetectorProxy(maskDetectorPoolSetting);

        // Eye-state detector (uses 5, the same value as the mask detector - see NOTE above).
        SeetaConfSetting eyeStaterPoolSetting = new SeetaConfSetting(new SeetaModelSetting(5,
                new String[]{basePath + seetaModelConfiguration.eyeStateModelFileName},
                SeetaDevice.SEETA_DEVICE_AUTO));
        eyeStateDetectorProxy = new EyeStateDetectorProxy(eyeStaterPoolSetting);
    } catch (Exception e) {
        // Best-effort startup as before, but report through the logger instead of stderr.
        log.error("Failed to initialise SeetaFace models from base path [{}]", basePath, e);
    }
}
主类里的根据图片路径获取脸部特征向量方法:
/*** 根据图片路径获取脸部特征向量** @param imagePath 图片路径* @return 脸部特征向量*/public float[] getFaceFeaturesByPath(String imagePath) {try {// 照片人脸识别SeetaImageData image = SeetafaceUtil.toSeetaImageData(imagePath);SeetaRect[] detects = faceDetectorProxy.detect(image);// 人脸关键点定位【主驾或副驾仅有一个人脸,多个人脸仅取第一个】if (detects.length > 0) {SeetaPointF[] pointFace = faceLandmarkerProxy.mark(image, detects[0]);// 人脸向量特征提取featuresreturn faceRecognizerProxy.extract(image, pointFace);}} catch (Exception e) {e.printStackTrace();}return null;}
主类里的根据人像图片的路径获取其属性【年龄、性别、是否戴口罩、眼睛状态】方法:
/*** 根据人像图片的路径获取其属性【年龄、性别、是否戴口罩、眼睛状态】** @param imagePath 图片路径* @return 图片属性 MAP 对象*/public Map<String, Object> getAttributeByPath(String imagePath) {long l = System.currentTimeMillis();Map<String, Object> attributeMap = new HashMap<>(4);try {// 监测人脸SeetaImageData image = SeetafaceUtil.toSeetaImageData(imagePath);SeetaRect[] detects = faceDetectorProxy.detect(image);if (detects.length > 0) {SeetaPointF[] pointFace = faceLandmarkerProxy.mark(image, detects[0]);// 获取年龄int age = agePredictorProxy.predictAgeWithCrop(image, pointFace);attributeMap.put("age", age);// 性别GenderPredictor.GENDER gender = genderPredictorProxy.predictGenderWithCrop(image, pointFace).getGender();attributeMap.put("gender", gender);// 口罩boolean mask = maskDetectorProxy.detect(image, detects[0]).getMask();attributeMap.put("mask", mask);// 眼睛EyeStateDetector.EYE_STATE[] eyeStates = eyeStateDetectorProxy.detect(image, pointFace);attributeMap.put("eye", Arrays.toString(eyeStates));log.info("getAttributeByPath [{}] cost [{}]", imagePath, System.currentTimeMillis() - l);}} catch (Exception e) {e.printStackTrace();return attributeMap;}return attributeMap;}