机器学习(5):支持向量机

embedded/2025/3/9 10:35:18/

1 介绍

        支持向量机(Support Vector Machine,简称 SVM)是一种监督学习算法,主要用于分类和回归问题。SVM 的核心思想是找到一个最优的超平面,将不同类别的数据分开。这个超平面不仅要能够正确分类数据,还要使得两个类别之间的间隔(margin)最大化。

1.1 线性可分

        在二维空间上,两类点被一条直线完全分开叫做线性可分。

        样本中距离超平面最近的一些点,这些点叫做支持向量。

1.2 软间隔

        在实际应用中,完全线性可分的样本是很少的,如果遇到了不能够完全线性可分的样本,我们应该怎么办?比如下面这个:

        于是我们就有了软间隔,相比于硬间隔的苛刻条件,我们允许个别样本点出现在间隔带里面,比如:

1.3 线性不可分

       我们刚刚讨论的硬间隔和软间隔都是在说样本的完全线性可分或者大部分样本点的线性可分。但我们可能会碰到的一种情况是样本点不是线性可分的,比如:

        这种情况的解决方法就是将二维线性不可分样本映射到高维空间中,让样本点在高维空间线性可分,比如:

        对于在有限维度向量空间中线性不可分的样本,我们将其映射到更高维度的向量空间里,再通过间隔最大化的方式,学习得到支持向量机,就是非线性 SVM。

1.4 优缺点

        优点

  • 有严格的数学理论支持,可解释性强,不依靠统计方法,从而简化了通常的分类和回归问题
  • 能找出对任务至关重要的关键样本(即:支持向量)
  • 采用核技巧之后,可以处理非线性分类/回归任务
  • 最终决策函数只由少数的支持向量所确定,计算的复杂性取决于支持向量的数目,而不是样本空间的维数,这在某种意义上避免了“维数灾难”。

        缺点

  • 训练时间长。当采用 SMO 算法时,由于每次都需要挑选一对参数,因此时间复杂度为 O(N2) ,其中 N 为训练样本的数量;
  • 当采用核技巧时,如果需要存储核矩阵,则空间复杂度为 O(N2) ;
  • 模型预测时,预测时间与支持向量的个数成正比。当支持向量的数量较大时,预测计算复杂度较高。

        因此支持向量机目前只适合小批量样本的任务,无法适应百万甚至上亿样本的任务。

2 使用 Python 实现 SVM

2.1 安装必要的库

        首先,确保你已经安装了scikit-learn库。如果没有安装,可以使用以下命令进行安装:

pip install scikit-learn

2.2 导入库

import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

2.3 加载数据集

        我们将使用scikit-learn自带的鸢尾花(Iris)数据集。

# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data[:, :2]  # 只使用前两个特征
y = iris.target

2.4 划分训练集和测试集

# 将数据集划分为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

2.5 训练 SVM 模型

# 创建SVM分类器
clf = svm.SVC(kernel='linear')  # 使用线性核函数# 训练模型
clf.fit(X_train, y_train)

2.6 预测与评估

# 在测试集上进行预测
y_pred = clf.predict(X_test)# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")

2.7 可视化结果

# 绘制决策边界
def plot_decision_boundary(X, y, model):h = .02  # 网格步长x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1xx, yy = np.meshgrid(np.arange(x_min, x_max, h),np.arange(y_min, y_max, h))Z = model.predict(np.c_[xx.ravel(), yy.ravel()])Z = Z.reshape(xx.shape)plt.contourf(xx, yy, Z, alpha=0.8)plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k', marker='o')plt.xlabel('Sepal length')plt.ylabel('Sepal width')plt.title('SVM Decision Boundary')plt.show()plot_decision_boundary(X_train, y_train, clf)


http://www.ppmy.cn/embedded/156270.html

相关文章

解决Oracle SQL语句性能问题(10.5)——常用Hint及语法(6)(并行相关Hint)

10.5.3. 常用hint 10.5.3.6. 并行相关Hint 1)parallel:显式的指示优化器为SQL语句或SQL语句中的特定表指定或计算并行度。该Hint具体语法如下所示。 SQL> select|insert|update|delete /*+ parallel[(integer|default|auto|manual)] */ ...; --注: 1)这里,parallel…

图论 八字码

我们可能惊异于某些技巧。我们认为这个技巧真是巧妙啊。或者有人认为我依靠自己的直觉想出了这个表示方法。非常自豪。我认为假设是很小的时候,比如说小学初中,还是不错的。到高中大学,就有一些不成熟了。因为这实际上是一个竞技。很多东西前…

【Elasticsearch】RestClient操作文档

RestClient操作文档 新增文档实体类API语法 查询文档删除文档修改文档批量导入文档小结 新增文档 将数据库中的信息导入elasticsearch中 以商品数据为例 实体类 定义一个索引库结构对应的实体。 Data ApiModel(description "索引库实体") public class ItemDoc{…

django使用踩坑经历

DRF 使用drf获取序列化后的id visitor_serializer VisitorSaveSerializer(data{…}) if visitor_serializer.is_valid():visitor visitor_serializer.save() visitor_id visitor.pkpostgrepsql踩坑 django使用postgrepsql,使用聚合函数如:sum 等,被…

ios文件管理,沙盒机制以及如何操作“文件”APP,把文件共享到文件app

首先,系统是一个整体,那每个app是相互独立的,系统为每个app分配了一定的存储空间,也就是我们说的沙盒,每个app有自己独立的沙盒,文件存储在沙盒中,正常情况下app相互之间数据是不可以共享以及访…

leetcode169.多数元素

给定一个大小为 n 的数组 nums ,返回其中的多数元素。多数元素是指在数组中出现次数 大于 ⌊ n/2 ⌋ 的元素。你可以假设数组是非空的,并且给定的数组总是存在多数元素。 示例 1: 输入:nums [3,2,3] 输出:3 示例 2…

PHP CRM售后系统小程序

💼 CRM售后系统 📺这是一款基于PHP和uniapp深度定制的CRM售后管理系统,它犹如企业的智慧核心,精准赋能销售与售后管理的每一个环节,引领企业步入精细化、数字化的全新管理时代。系统集成了客户管理、合同管理、工单调…

llama 3 笔记

0.简介 llama 3 是在 15 万亿个 Token 上预训练的语言模型,具有 8B 和 70B 两种参数规模,可以支持广泛的用户场景,在各种行业基准上取得了最先进的性能,并提供了一些新功能,包括改进的推理能力。 1.改进亮点 参数规模与模型架构:Llama 3提供了8B和70B两种参数规模的模…