LLM Inference in C: Exploring the llama2.c Repository (Part 2)

Table of Contents

  • Background
  • How to Build a Transformer Model
    • Model Definition
    • Model Initialization
  • How to Build the Tokenizer and Sampler
  • How to Run Inference
  • Summary

Background

In the previous part we looked at how llama2.c processes the Hugging Face weights, producing the weight format llama2.c expects along with the tokenizer.bin file. In this part we analyze how llama2.c parses these two `.bin` files. All of the code discussed here lives in `run.c`.

LLM Inference in C: Exploring the llama2.c Repository (Part 1)

How to Build a Transformer Model

The simplest way to think about it: we build a Transformer model in C, then fill in the two `.bin` files according to their formats. So what data structure, or what kind of organization, should this Transformer model have? C has no notion of a `class`; the closest common construct is the struct, and a struct can only hold variables, not member functions. How, then, are the operators that act on the Transformer model implemented? With these questions (and whatever others you may have) in mind, let's walk step by step through how llama2.c does it.

Model Definition

<code class="prism language-c">class="token keyword">typedef class="token keyword">struct class="token punctuation">{class="token keyword">int dimclass="token punctuation">;        class="token comment">// transformer dimensionclass="token keyword">int hidden_dimclass="token punctuation">; class="token comment">// for ffn layersclass="token keyword">int n_layersclass="token punctuation">;   class="token comment">// number of layersclass="token keyword">int n_headsclass="token punctuation">;    class="token comment">// number of query headsclass="token keyword">int n_kv_headsclass="token punctuation">; class="token comment">// number of key/value heads (can be < query heads because ofclass="token comment">// multiquery)class="token keyword">int vocab_sizeclass="token punctuation">; class="token comment">// vocabulary size, usually 256 (byte-level)class="token keyword">int seq_lenclass="token punctuation">;    class="token comment">// max sequence length
class="token punctuation">} Configclass="token punctuation">;class="token keyword">typedef class="token keyword">struct class="token punctuation">{class="token comment">// token embedding tableclass="token keyword">float class="token operator">*token_embedding_tableclass="token punctuation">; class="token comment">// (vocab_size, dim)class="token comment">// weights for rmsnormsclass="token keyword">float class="token operator">*rms_att_weightclass="token punctuation">; class="token comment">// (layer, dim) rmsnorm weightsclass="token keyword">float class="token operator">*rms_ffn_weightclass="token punctuation">; class="token comment">// (layer, dim)class="token comment">// weights for matmuls. note dim == n_heads * head_sizeclass="token keyword">float class="token operator">*wqclass="token punctuation">; class="token comment">// (layer, dim, n_heads * head_size)class="token keyword">float class="token operator">*wkclass="token punctuation">; class="token comment">// (layer, dim, n_kv_heads * head_size)class="token keyword">float class="token operator">*wvclass="token punctuation">; class="token comment">// (layer, dim, n_kv_heads * head_size)class="token keyword">float class="token operator">*woclass="token punctuation">; class="token comment">// (layer, n_heads * head_size, dim)class="token comment">// weights for ffnclass="token keyword">float class="token operator">*w1class="token punctuation">; class="token comment">// (layer, hidden_dim, dim)class="token keyword">float class="token operator">*w2class="token punctuation">; class="token comment">// (layer, dim, hidden_dim)class="token keyword">float class="token operator">*w3class="token punctuation">; class="token comment">// (layer, hidden_dim, dim)class="token comment">// final rmsnormclass="token keyword">float class="token operator">*rms_final_weightclass="token punctuation">; class="token comment">// (dim,)class="token comment">// (optional) classifier weights for the logits, on the last 
layerclass="token keyword">float class="token operator">*wclsclass="token punctuation">;
class="token punctuation">} TransformerWeightsclass="token punctuation">;class="token keyword">typedef class="token keyword">struct class="token punctuation">{class="token comment">// current wave of activationsclass="token keyword">float class="token operator">*xclass="token punctuation">;      class="token comment">// activation at current time stamp (dim,)class="token keyword">float class="token operator">*xbclass="token punctuation">;     class="token comment">// same, but inside a residual branch (dim,)class="token keyword">float class="token operator">*xb2class="token punctuation">;    class="token comment">// an additional buffer just for convenience (dim,)class="token keyword">float class="token operator">*hbclass="token punctuation">;     class="token comment">// buffer for hidden dimension in the ffn (hidden_dim,)class="token keyword">float class="token operator">*hb2class="token punctuation">;    class="token comment">// buffer for hidden dimension in the ffn (hidden_dim,)class="token keyword">float class="token operator">*qclass="token punctuation">;      class="token comment">// query (dim,)class="token keyword">float class="token operator">*kclass="token punctuation">;      class="token comment">// key (dim,)class="token keyword">float class="token operator">*vclass="token punctuation">;      class="token comment">// value (dim,)class="token keyword">float class="token operator">*attclass="token punctuation">;    class="token comment">// buffer for scores/attention values (n_heads, seq_len)class="token keyword">float class="token operator">*logitsclass="token punctuation">; class="token comment">// output logitsclass="token comment">// kv cacheclass="token keyword">float class="token operator">*key_cacheclass="token punctuation">;   class="token comment">// (layer, seq_len, dim)class="token keyword">float class="token operator">*value_cacheclass="token punctuation">; class="token comment">// (layer, seq_len, dim)
class="token punctuation">} RunStateclass="token punctuation">;class="token keyword">typedef class="token keyword">struct class="token punctuation">{Config configclass="token punctuation">; class="token comment">// the hyperparameters of the architecture (the blueprint)TransformerWeights weightsclass="token punctuation">; class="token comment">// the weights of the modelRunState stateclass="token punctuation">; class="token comment">// buffers for the "wave" of activations in the forward passclass="token comment">// some more state needed to properly clean up the memory mapping (sigh)class="token keyword">int fdclass="token punctuation">;            class="token comment">// file descriptor for memory mappingclass="token keyword">float class="token operator">*dataclass="token punctuation">;       class="token comment">// memory mapped data pointerclass="token class-name">ssize_t file_sizeclass="token punctuation">; class="token comment">// size of the checkpoint file in bytes
class="token punctuation">} Transformerclass="token punctuation">;
code>

<code>class="tags" href="/LLAMA2.html" title=llama2>llama2.ccode>中的Transformer是一个结构体࿰c;其中最重要的三个成员变量是<code>configcode>࿰c;<code>weightscode>࿰c;<code>statecode>࿰c;分别保存了网络的超参数࿰c;权重࿰c;以及网络运行过程中的中间结果。
强烈建议这里你仔细理解理解࿰c;体会一下这个写法。

Model Initialization

Next we initialize the model we just defined, which comes down to two things: initializing the weights and initializing the intermediate buffers. Here `llama2.c`'s technique is even more impressive. Study the two functions below:

The weight initialization function:

<code class="prism language-c">class="token keyword">void class="token function">memory_map_weightsclass="token punctuation">(TransformerWeights class="token operator">*wclass="token punctuation">, Config class="token operator">*pclass="token punctuation">, class="token keyword">float class="token operator">*ptrclass="token punctuation">,class="token keyword">int shared_weightsclass="token punctuation">) class="token punctuation">{class="token keyword">int head_size class="token operator">= pclass="token operator">->dim class="token operator">/ pclass="token operator">->n_headsclass="token punctuation">;class="token comment">// make sure the multiplications below are done in 64bit to fit the parameterclass="token comment">// counts of 13B+ modelsclass="token keyword">unsigned class="token keyword">long class="token keyword">long n_layers class="token operator">= pclass="token operator">->n_layersclass="token punctuation">;wclass="token operator">->token_embedding_table class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= pclass="token operator">->vocab_size class="token operator">* pclass="token operator">->dimclass="token punctuation">;wclass="token operator">->rms_att_weight class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dimclass="token punctuation">;wclass="token operator">->wq class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dim class="token operator">* class="token punctuation">(pclass="token operator">->n_heads class="token operator">* head_sizeclass="token punctuation">)class="token punctuation">;wclass="token operator">->wk class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dim class="token operator">* class="token punctuation">(pclass="token 
operator">->n_kv_heads class="token operator">* head_sizeclass="token punctuation">)class="token punctuation">;wclass="token operator">->wv class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dim class="token operator">* class="token punctuation">(pclass="token operator">->n_kv_heads class="token operator">* head_sizeclass="token punctuation">)class="token punctuation">;wclass="token operator">->wo class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* class="token punctuation">(pclass="token operator">->n_heads class="token operator">* head_sizeclass="token punctuation">) class="token operator">* pclass="token operator">->dimclass="token punctuation">;wclass="token operator">->rms_ffn_weight class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dimclass="token punctuation">;wclass="token operator">->w1 class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dim class="token operator">* pclass="token operator">->hidden_dimclass="token punctuation">;wclass="token operator">->w2 class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->hidden_dim class="token operator">* pclass="token operator">->dimclass="token punctuation">;wclass="token operator">->w3 class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dim class="token operator">* pclass="token operator">->hidden_dimclass="token punctuation">;wclass="token operator">->rms_final_weight class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= pclass="token 
operator">->dimclass="token punctuation">;ptr class="token operator">+= pclass="token operator">->seq_len class="token operator">* head_size class="token operator">/class="token number">2class="token punctuation">; class="token comment">// skip what used to be freq_cis_real (for RoPE)ptr class="token operator">+= pclass="token operator">->seq_len class="token operator">* head_size class="token operator">/class="token number">2class="token punctuation">; class="token comment">// skip what used to be freq_cis_imag (for RoPE)wclass="token operator">->wcls class="token operator">= shared_weights class="token operator">? wclass="token operator">->token_embedding_table class="token operator">: ptrclass="token punctuation">;
class="token punctuation">}
code>

To me, this is one of the most classic pieces of code in the repository: the weights are never actually loaded here. We only take the address of the memory-mapped file and assign each slice of it to the corresponding struct member. Later, when inference actually uses a given slice of weights, only that slice is paged into memory to participate in the computation.

Intermediate buffer initialization:

<code class="prism language-cpp">class="token keyword">void class="token function">malloc_run_stateclass="token punctuation">(RunState class="token operator">*sclass="token punctuation">, Config class="token operator">*pclass="token punctuation">) class="token punctuation">{class="token comment">// we calloc instead of malloc to keep valgrind happyclass="token keyword">int kv_dim class="token operator">= class="token punctuation">(pclass="token operator">->dim class="token operator">* pclass="token operator">->n_kv_headsclass="token punctuation">) class="token operator">/ pclass="token operator">->n_headsclass="token punctuation">;sclass="token operator">->x class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->xb class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->xb2 class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->hb class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->hidden_dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->hb2 
class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->hidden_dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->q class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->key_cache class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->n_layers class="token operator">* pclass="token operator">->seq_len class="token operator">* kv_dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->value_cache class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->n_layers class="token operator">* pclass="token operator">->seq_len class="token operator">* kv_dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->att class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->n_heads class="token operator">* pclass="token operator">->seq_lenclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->logits class="token operator">= class="token function">callocclass="token 
punctuation">(pclass="token operator">->vocab_sizeclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;class="token comment">// ensure all mallocs went fineclass="token keyword">if class="token punctuation">(class="token operator">!sclass="token operator">->x class="token operator">|| class="token operator">!sclass="token operator">->xb class="token operator">|| class="token operator">!sclass="token operator">->xb2 class="token operator">|| class="token operator">!sclass="token operator">->hb class="token operator">|| class="token operator">!sclass="token operator">->hb2 class="token operator">|| class="token operator">!sclass="token operator">->q class="token operator">||class="token operator">!sclass="token operator">->key_cache class="token operator">|| class="token operator">!sclass="token operator">->value_cache class="token operator">|| class="token operator">!sclass="token operator">->att class="token operator">|| class="token operator">!sclass="token operator">->logitsclass="token punctuation">) class="token punctuation">{class="token function">fprintfclass="token punctuation">(class="token constant">stderrclass="token punctuation">, class="token string">"malloc failed!\n"class="token punctuation">)class="token punctuation">;class="token function">exitclass="token punctuation">(EXIT_FAILUREclass="token punctuation">)class="token punctuation">;class="token punctuation">}
class="token punctuation">}
code>

If it is not obvious why the weight and buffer initialization allocate exactly these sizes, trace the network's data flow by hand from start to finish.

How to Build the Tokenizer and Sampler

We will not go into the construction of these two modules here; if you are interested, read the source yourself.

How to Run Inference

This is the part I find most interesting.

<code class="prism language-c">  class="token comment">// forward all the layersclass="token keyword">for class="token punctuation">(class="token keyword">unsigned class="token keyword">long class="token keyword">long l class="token operator">= class="token number">0class="token punctuation">; l class="token operator">< pclass="token operator">->n_layersclass="token punctuation">; lclass="token operator">++class="token punctuation">) class="token punctuation">{class="token comment">// attention rmsnormclass="token function">rmsnormclass="token punctuation">(sclass="token operator">->xbclass="token punctuation">, xclass="token punctuation">, wclass="token operator">->rms_att_weight class="token operator">+ l class="token operator">* dimclass="token punctuation">, dimclass="token punctuation">)class="token punctuation">;class="token comment">// key and value point to the kv cacheclass="token keyword">int loff class="token operator">= l class="token operator">* pclass="token operator">->seq_len class="token operator">* kv_dimclass="token punctuation">; class="token comment">// kv cache layer offset for conveniencesclass="token operator">->k class="token operator">= sclass="token operator">->key_cache class="token operator">+ loff class="token operator">+ pos class="token operator">* kv_dimclass="token punctuation">;sclass="token operator">->v class="token operator">= sclass="token operator">->value_cache class="token operator">+ loff class="token operator">+ pos class="token operator">* kv_dimclass="token punctuation">;class="token comment">// qkv matmuls for this positionclass="token function">matmulclass="token punctuation">(sclass="token operator">->qclass="token punctuation">, sclass="token operator">->xbclass="token punctuation">, wclass="token operator">->wq class="token operator">+ l class="token operator">* dim class="token operator">* dimclass="token punctuation">, dimclass="token punctuation">, dimclass="token punctuation">)class="token 
punctuation">;class="token function">matmulclass="token punctuation">(sclass="token operator">->kclass="token punctuation">, sclass="token operator">->xbclass="token punctuation">, wclass="token operator">->wk class="token operator">+ l class="token operator">* dim class="token operator">* kv_dimclass="token punctuation">, dimclass="token punctuation">, kv_dimclass="token punctuation">)class="token punctuation">;class="token function">matmulclass="token punctuation">(sclass="token operator">->vclass="token punctuation">, sclass="token operator">->xbclass="token punctuation">, wclass="token operator">->wv class="token operator">+ l class="token operator">* dim class="token operator">* kv_dimclass="token punctuation">, dimclass="token punctuation">, kv_dimclass="token punctuation">)class="token punctuation">;class="token comment">// RoPE relative positional encoding: complex-valued rotate q and k in eachclass="token comment">// headclass="token keyword">for class="token punctuation">(class="token keyword">int i class="token operator">= class="token number">0class="token punctuation">; i class="token operator">< dimclass="token punctuation">; i class="token operator">+= class="token number">2class="token punctuation">) class="token punctuation">{class="token keyword">int head_dim class="token operator">= i class="token operator">% head_sizeclass="token punctuation">;class="token keyword">float freq class="token operator">= class="token number">1.0f class="token operator">/ class="token function">powfclass="token punctuation">(class="token number">10000.0fclass="token punctuation">, head_dim class="token operator">/ class="token punctuation">(class="token keyword">floatclass="token punctuation">)head_sizeclass="token punctuation">)class="token punctuation">;class="token keyword">float val class="token operator">= pos class="token operator">* freqclass="token punctuation">;class="token keyword">float fcr class="token operator">= class="token function">cosfclass="token 
punctuation">(valclass="token punctuation">)class="token punctuation">;class="token keyword">float fci class="token operator">= class="token function">sinfclass="token punctuation">(valclass="token punctuation">)class="token punctuation">;class="token keyword">int rotn class="token operator">= i class="token operator">< kv_dim class="token operator">? class="token number">2 class="token operator">: class="token number">1class="token punctuation">; class="token comment">// how many vectors? 2 = q & k, 1 = q onlyclass="token keyword">for class="token punctuation">(class="token keyword">int v class="token operator">= class="token number">0class="token punctuation">; v class="token operator">< rotnclass="token punctuation">; vclass="token operator">++class="token punctuation">) class="token punctuation">{class="token keyword">float class="token operator">*vec class="token operator">=v class="token operator">== class="token number">0 class="token operator">? sclass="token operator">->q class="token operator">: sclass="token operator">->kclass="token punctuation">; class="token comment">// the vector to rotate (query or key)class="token keyword">float v0 class="token operator">= vecclass="token punctuation">[iclass="token punctuation">]class="token punctuation">;class="token keyword">float v1 class="token operator">= vecclass="token punctuation">[i class="token operator">+ class="token number">1class="token punctuation">]class="token punctuation">;vecclass="token punctuation">[iclass="token punctuation">] class="token operator">= v0 class="token operator">* fcr class="token operator">- v1 class="token operator">* fciclass="token punctuation">;vecclass="token punctuation">[i class="token operator">+ class="token number">1class="token punctuation">] class="token operator">= v0 class="token operator">* fci class="token operator">+ v1 class="token operator">* fcrclass="token punctuation">;class="token punctuation">}class="token punctuation">}class="token comment">// 
multihead attention. iterate over all headsclass="token keyword">int hclass="token punctuation">;
class="token macro property">class="token directive-hash">#class="token directive keyword">pragma class="token expression">omp parallel class="token keyword">for class="token function">privateclass="token punctuation">(hclass="token punctuation">)class="token keyword">for class="token punctuation">(h class="token operator">= class="token number">0class="token punctuation">; h class="token operator">< pclass="token operator">->n_headsclass="token punctuation">; hclass="token operator">++class="token punctuation">) class="token punctuation">{class="token comment">// get the query vector for this headclass="token keyword">float class="token operator">*q class="token operator">= sclass="token operator">->q class="token operator">+ h class="token operator">* head_sizeclass="token punctuation">;class="token comment">// attention scores for this headclass="token keyword">float class="token operator">*att class="token operator">= sclass="token operator">->att class="token operator">+ h class="token operator">* pclass="token operator">->seq_lenclass="token punctuation">;class="token comment">// iterate over all timesteps, including the current oneclass="token keyword">for class="token punctuation">(class="token keyword">int t class="token operator">= class="token number">0class="token punctuation">; t class="token operator"><= posclass="token punctuation">; tclass="token operator">++class="token punctuation">) class="token punctuation">{class="token comment">// get the key vector for this head and at this timestepclass="token keyword">float class="token operator">*k class="token operator">= sclass="token operator">->key_cache class="token operator">+ loff class="token operator">+ t class="token operator">* kv_dim class="token operator">+ class="token punctuation">(h class="token operator">/ kv_mulclass="token punctuation">) class="token operator">* head_sizeclass="token punctuation">;class="token comment">// calculate the attention score as the dot product of q and 
kclass="token keyword">float score class="token operator">= class="token number">0.0fclass="token punctuation">;class="token keyword">for class="token punctuation">(class="token keyword">int i class="token operator">= class="token number">0class="token punctuation">; i class="token operator">< head_sizeclass="token punctuation">; iclass="token operator">++class="token punctuation">) class="token punctuation">{score class="token operator">+= qclass="token punctuation">[iclass="token punctuation">] class="token operator">* kclass="token punctuation">[iclass="token punctuation">]class="token punctuation">;class="token punctuation">}score class="token operator">/= class="token function">sqrtfclass="token punctuation">(head_sizeclass="token punctuation">)class="token punctuation">;class="token comment">// save the score to the attention bufferattclass="token punctuation">[tclass="token punctuation">] class="token operator">= scoreclass="token punctuation">;class="token punctuation">}class="token comment">// softmax the scores to get attention weights, from 0..pos inclusivelyclass="token function">softmaxclass="token punctuation">(attclass="token punctuation">, pos class="token operator">+ class="token number">1class="token punctuation">)class="token punctuation">;class="token comment">// weighted sum of the values, store back into xbclass="token keyword">float class="token operator">*xb class="token operator">= sclass="token operator">->xb class="token operator">+ h class="token operator">* head_sizeclass="token punctuation">;class="token function">memsetclass="token punctuation">(xbclass="token punctuation">, class="token number">0class="token punctuation">, head_size class="token operator">* class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;class="token keyword">for class="token punctuation">(class="token keyword">int t class="token operator">= class="token 
0cl">
```c
    for (int t = 0; t <= pos; t++) {
        // get the value vector for this head and at this timestep
        float *v = s->value_cache + loff + t * kv_dim + (h / kv_mul) * head_size;
        // get the attention weight for this timestep
        float a = att[t];
        // accumulate the weighted value into xb
        for (int i = 0; i < head_size; i++) {
            xb[i] += a * v[i];
        }
    }
}

// final matmul to get the output of the attention
matmul(s->xb2, s->xb, w->wo + l * dim * dim, dim, dim);

// residual connection back into x
for (int i = 0; i < dim; i++) {
    x[i] += s->xb2[i];
}

// ffn rmsnorm
rmsnorm(s->xb, x, w->rms_ffn_weight + l * dim, dim);

// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
// first calculate self.w1(x) and self.w3(x)
matmul(s->hb, s->xb, w->w1 + l * dim * hidden_dim, dim, hidden_dim);
matmul(s->hb2, s->xb, w->w3 + l * dim * hidden_dim, dim, hidden_dim);

// SwiGLU non-linearity
for (int i = 0; i < hidden_dim; i++) {
    float val = s->hb[i];
    // silu(x) = x * σ(x), where σ(x) is the logistic sigmoid
    val *= (1.0f / (1.0f + expf(-val)));
    // elementwise multiply with w3(x)
    val *= s->hb2[i];
    s->hb[i] = val;
}

// final matmul to get the output of the ffn
matmul(s->xb, s->hb, w->w2 + l * dim * hidden_dim, hidden_dim, dim);

// residual connection
for (int i = 0; i < dim; i++) {
    x[i] += s->xb[i];
}
}
```

The `for` loop runs inference over all of the layers. It relies on three main helper functions, `rmsnorm`, `matmul`, and `softmax`, each corresponding to one operator; the remaining operators are implemented inline inside the `for` loop itself. Once every layer has been computed, a post-processing step completes the inference of one token.

Summary

All in all, this repository has a lot to teach us; it is well worth studying the author's way of thinking about and writing code.


