lightgbm做分类

news/2025/2/2 1:58:47/

```python
import pandas as pd#导入csv文件的库
import numpy as np#进行矩阵运算的库
import json#用于读取和写入json数据格式#model lgb分类模型,日志评估,早停防止过拟合
from  lightgbm import LGBMClassifier,log_evaluation,early_stopping
#metric
from sklearn.metrics import roc_auc_score#导入roc_auc曲线
#KFold是直接分成k折,StratifiedKFold还要考虑每种类别的占比
from sklearn.model_selection import StratifiedKFold#config
class Config():seed=2024#随机种子num_folds=10#K折交叉验证TARGET_NAME ='label'#标签
import random#提供了一些用于生成随机数的函数
#设置随机种子,保证模型可以复现
def seed_everything(seed):np.random.seed(seed)#numpy的随机种子random.seed(seed)#python内置的随机种子
seed_everything(Config.seed)path='/kaggle/input/'
#sample: Iki037dt dict_keys(['name', 'normal_data', 'outliers'])
with open(path+"whoiswho-ind-kdd-2024/IND-WhoIsWho/train_author.json") as f:train_author=json.load(f)
#sample : 6IsfnuWU dict_keys(['id', 'title', 'authors', 'abstract', 'keywords', 'venue', 'year'])   
with open(path+"whoiswho-ind-kdd-2024/IND-WhoIsWho/pid_to_info_all.json") as f:pid_to_info=json.load(f)
#efQ8FQ1i dict_keys(['name', 'papers'])
with open(path+"whoiswho-ind-kdd-2024/IND-WhoIsWho/ind_valid_author.json") as f:valid_author=json.load(f)with open(path+"whoiswho-ind-kdd-2024/IND-WhoIsWho/ind_valid_author_submit.json") as f:submission=json.load(f)train_feats=[]
labels=[]
for id,person_info in train_author.items():for text_id in person_info['normal_data']:#正样本feat=pid_to_info[text_id]#['title', 'abstract', 'keywords', 'authors', 'venue', 'year']try:train_feats.append([len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors']),len(feat['keywords']),int(feat['year'])])except:train_feats.append([len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors']),len(feat['keywords']),2000])labels.append(1)for text_id in person_info['outliers']:#负样本feat=pid_to_info[text_id]#['title', 'abstract', 'keywords', 'authors', 'venue', 'year']try:train_feats.append([len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors']),len(feat['keywords']),int(feat['year'])])except:train_feats.append([len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors']),len(feat['keywords']),2000])labels.append(0)   
train_feats=np.array(train_feats)
labels=np.array(labels)
print(f"train_feats.shape:{train_feats.shape},labels.shape:{labels.shape}")
print(f"np.mean(labels):{np.mean(labels)}")
train_feats=pd.DataFrame(train_feats)
train_feats['label']=labels
train_feats.head()valid_feats=[]
for id,person_info in valid_author.items():for text_id in person_info['papers']:feat=pid_to_info[text_id]#['title', 'abstract', 'keywords', 'authors', 'venue', 'year']try:valid_feats.append([len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors']),len(feat['keywords']),int(feat['year'])])except:valid_feats.append([len(feat['title']),len(feat['abstract']),len(feat['keywords']),len(feat['authors']),len(feat['keywords']),2000])
valid_feats=np.array(valid_feats)
print(f"valid_feats.shape:{valid_feats.shape}")
valid_feats=pd.DataFrame(valid_feats)
valid_feats.head()choose_cols=[col for col in valid_feats.columns]
def fit_and_predict(model,train_feats=train_feats,test_feats=valid_feats,name=0):X=train_feats[choose_cols].copy()y=train_feats[Config.TARGET_NAME].copy()test_X=test_feats[choose_cols].copy()oof_pred_pro=np.zeros((len(X),2))test_pred_pro=np.zeros((Config.num_folds,len(test_X),2))#10折交叉验证skf = StratifiedKFold(n_splits=Config.num_folds,random_state=Config.seed, shuffle=True)for fold, (train_index, valid_index) in (enumerate(skf.split(X, y.astype(str)))):print(f"name:{name},fold:{fold}")X_train, X_valid = X.iloc[train_index], X.iloc[valid_index]y_train, y_valid = y.iloc[train_index], y.iloc[valid_index]model.fit(X_train,y_train,eval_set=[(X_valid, y_valid)],callbacks=[log_evaluation(100),early_stopping(100)])oof_pred_pro[valid_index]=model.predict_proba(X_valid)#将数据分批次进行预测.test_pred_pro[fold]=model.predict_proba(test_X)print(f"roc_auc:{roc_auc_score(y.values,oof_pred_pro[:,1])}")return oof_pred_pro,test_pred_pro
#参数来源:https://www.kaggle.com/code/daviddirethucus/home-credit-risk-lightgbm
lgb_params={"boosting_type": "gbdt","objective": "binary","metric": "auc","max_depth": 12,"learning_rate": 0.05,"n_estimators":3072,"colsample_bytree": 0.9,"colsample_bynode": 0.9,"verbose": -1,"random_state": Config.seed,"reg_alpha": 0.1,"reg_lambda": 10,"extra_trees":True,'num_leaves':64,"verbose": -1,"max_bin":255,}lgb_oof_pred_pro,lgb_test_pred_pro=fit_and_predict(model= LGBMClassifier(**lgb_params),name='lgb')
test_preds=lgb_test_pred_pro.mean(axis=0)[:,1]cnt=0
for id,names in submission.items():for name in names:submission[id][name]=test_preds[cnt]cnt+=1
with open('baseline.json', 'w', encoding='utf-8') as f:json.dump(submission, f, ensure_ascii=False, indent=4)


http://www.ppmy.cn/news/1568567.html

相关文章

2006-2021年 省级数字经济与实体经济融合水平计算代码及原始数据-社科数据

省级数字经济与实体经济融合水平计算代码及原始数据2006-2021年-社科数据https://download.csdn.net/download/paofuluolijiang/90028609 https://download.csdn.net/download/paofuluolijiang/90028609 数字经济与实体经济的融合是推动现代经济发展的关键力量。从2006年至20…

Vue.js 生命周期钩子在 Composition API 中的应用

Vue.js 生命周期钩子在 Composition API 中的应用 今天我们来聊聊在 Vue 3 的组合式 API(Composition API)中,如何使用生命周期钩子。如果你对如何在 setup() 函数中处理组件的生命周期事件感到困惑,那么这篇文章将为你解答。 什…

vulfocus/thinkphp:6.0.12 命令执行

本次测试是在vulfocus靶场上进行 漏洞介绍 在其6.0.13版本及以前,存在一处本地文件包含漏洞。当多语言特性被开启时,攻击者可以使用lang参数来包含任意PHP文件。 虽然只能包含本地PHP文件,但在开启了register_argc_argv且安装了pcel/pear的环境下,可以包含/usr/local/lib/…

遗传算法【Genetic Algorithm(GA)】求解函数最大值(MATLAB and Python实现)

一、遗传算法基础知识 来自B站视频的笔记: 【超容易理解】手把手逐句带你解读并实现遗传算法的MATLAB编程(结合理论基础)_哔哩哔哩_bilibili 1、遗传算法 使用“适者生存”的原则,在遗传算法的每一代中,…

MATLAB中extractAfter函数用法

目录 语法 说明 示例 选择子字符串后的文本 使用模式提取路径后的文件名 选择指定位置后的子字符串 选择字符向量中位置之后的文本 extractAfter函数的用法是提取指定位置后的子字符串。 语法 newStr extractAfter(str,pat) newStr extractAfter(str,pos) 说明 new…

qt-Quick3D笔记之官方例程Runtimeloader Example运行笔记

qt-Quick3D笔记之官方例程Runtimeloader Example运行笔记 文章目录 qt-Quick3D笔记之官方例程Runtimeloader Example运行笔记1.例程运行效果2.例程缩略图3.项目文件列表4.main.qml5.main.cpp6.CMakeLists.txt 1.例程运行效果 运行该项目需要自己准备一个模型文件 2.例程缩略图…

【内蒙古乡镇界】面图层shp格式+乡镇名称和编码wgs84坐标无偏移arcgis数据内容测评

最新2020年乡镇界面图层shp格式arcgis数据乡镇名称和编码wgs84坐标无偏移。arcgis直接打开,单独乡镇界一个图层。品质高

【贪心算法】在有盾牌的情况下能通过每轮伤害的最小值(亚马逊笔试题)

思路&#xff1a; 采用贪心算法&#xff0c;先计算出来所有的伤害值&#xff0c;然后再计算每轮在使用盾牌的情况下能减少伤害的最大值&#xff0c;最后用总的伤害值减去能减少的最大值就是最少的总伤害值 public static long getMinimumValue(List<Integer> power, int…