content_views"
class="markdown_views prism-atom-one-dark">
前提
上一节我们介绍了 llama2.c 中如何对 Hugging Face 的权重进行处理,拿到了 llama2.c 想要的权重格式和 tokenizer.bin 格式。这一节我们分析下 llama2.c 如何解析这两个 .bin 文件。这一节的所有代码都在 run.c 文件里。
用 C 语言进行大模型推理:探索 llama2.c 仓库(一)
如何构建一个Transformer Model
按照一个最简单的理解,我们可以使用 C 语言构建一个 Transformer Model,然后将两个 .bin 文件按照格式填进去即可。那这个 Transformer Model 应该是一个什么数据结构,或者是一个什么样的组织架构呢?C 语言中没有 class 这个概念,我们最常见的也就是结构体了,而且结构体里只能定义变量,不能定义函数。那么操作 Transformer Model 的那些算子又该如何实现呢?带着这些问题,或者你还有其他的问题,我们一步一步来看下 llama2.c 中是如何实现的。
模型定义
// Hyperparameters of the Transformer architecture (the model "blueprint"),
// read from the header of the checkpoint .bin file.
typedef struct {
    int dim;        // transformer dimension
    int hidden_dim; // for ffn layers
    int n_layers;   // number of layers
    int n_heads;    // number of query heads
    int n_kv_heads; // number of key/value heads (can be < query heads because of
                    // multiquery)
    int vocab_size; // vocabulary size, usually 256 (byte-level)
    int seq_len;    // max sequence length
} Config;

// Pointers into the memory-mapped checkpoint: one entry per weight tensor.
typedef struct {
    // token embedding table
    float *token_embedding_table; // (vocab_size, dim)
    // weights for rmsnorms
    float *rms_att_weight; // (layer, dim) rmsnorm weights
    float *rms_ffn_weight; // (layer, dim)
    // weights for matmuls. note dim == n_heads * head_size
    float *wq; // (layer, dim, n_heads * head_size)
    float *wk; // (layer, dim, n_kv_heads * head_size)
    float *wv; // (layer, dim, n_kv_heads * head_size)
    float *wo; // (layer, n_heads * head_size, dim)
    // weights for ffn
    float *w1; // (layer, hidden_dim, dim)
    float *w2; // (layer, dim, hidden_dim)
    float *w3; // (layer, hidden_dim, dim)
    // final rmsnorm
    float *rms_final_weight; // (dim,)
    // (optional) classifier weights for the logits, on the last layer
    float *wcls;
} TransformerWeights;

// Scratch buffers for the "wave" of activations in the forward pass,
// allocated once by malloc_run_state and reused for every token.
typedef struct {
    // current wave of activations
    float *x;   // activation at current time stamp (dim,)
    float *xb;  // same, but inside a residual branch (dim,)
    float *xb2; // an additional buffer just for convenience (dim,)
    float *hb;  // buffer for hidden dimension in the ffn (hidden_dim,)
    float *hb2; // buffer for hidden dimension in the ffn (hidden_dim,)
    float *q;   // query (dim,)
    float *k;   // key (dim,)
    float *v;   // value (dim,)
    float *att; // buffer for scores/attention values (n_heads, seq_len)
    float *logits; // output logits
    // kv cache
    float *key_cache;   // (layer, seq_len, dim)
    float *value_cache; // (layer, seq_len, dim)
} RunState;

// The full model: hyperparameters + weights + run state, plus the bookkeeping
// needed to clean up the memory mapping of the checkpoint file.
typedef struct {
    Config config; // the hyperparameters of the architecture (the blueprint)
    TransformerWeights weights; // the weights of the model
    RunState state; // buffers for the "wave" of activations in the forward pass
    // some more state needed to properly clean up the memory mapping (sigh)
    int fd;            // file descriptor for memory mapping
    float *data;       // memory mapped data pointer
    ssize_t file_size; // size of the checkpoint file in bytes
} Transformer;
llama2.c 中的 Transformer 是一个结构体,其中最重要的三个成员变量是 config、weights、state,分别保存了网络的超参数、权重,以及网络运行过程中的中间结果。
强烈建议你在这里仔细理解一下,体会这种写法。
模型初始化
我们要对定义的模型进行初始化,主要是两个方面:权重初始化和中间变量初始化。这里 llama2.c 的写法就更厉害了。请仔细欣赏下面的两个函数:
权重初始化函数:
<code class="prism language-c">class="token keyword">void class="token function">memory_map_weightsclass="token punctuation">(TransformerWeights class="token operator">*wclass="token punctuation">, Config class="token operator">*pclass="token punctuation">, class="token keyword">float class="token operator">*ptrclass="token punctuation">,class="token keyword">int shared_weightsclass="token punctuation">) class="token punctuation">{class="token keyword">int head_size class="token operator">= pclass="token operator">->dim class="token operator">/ pclass="token operator">->n_headsclass="token punctuation">;class="token comment">// make sure the multiplications below are done in 64bit to fit the parameterclass="token comment">// counts of 13B+ modelsclass="token keyword">unsigned class="token keyword">long class="token keyword">long n_layers class="token operator">= pclass="token operator">->n_layersclass="token punctuation">;wclass="token operator">->token_embedding_table class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= pclass="token operator">->vocab_size class="token operator">* pclass="token operator">->dimclass="token punctuation">;wclass="token operator">->rms_att_weight class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dimclass="token punctuation">;wclass="token operator">->wq class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dim class="token operator">* class="token punctuation">(pclass="token operator">->n_heads class="token operator">* head_sizeclass="token punctuation">)class="token punctuation">;wclass="token operator">->wk class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dim class="token operator">* class="token punctuation">(pclass="token 
operator">->n_kv_heads class="token operator">* head_sizeclass="token punctuation">)class="token punctuation">;wclass="token operator">->wv class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dim class="token operator">* class="token punctuation">(pclass="token operator">->n_kv_heads class="token operator">* head_sizeclass="token punctuation">)class="token punctuation">;wclass="token operator">->wo class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* class="token punctuation">(pclass="token operator">->n_heads class="token operator">* head_sizeclass="token punctuation">) class="token operator">* pclass="token operator">->dimclass="token punctuation">;wclass="token operator">->rms_ffn_weight class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dimclass="token punctuation">;wclass="token operator">->w1 class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dim class="token operator">* pclass="token operator">->hidden_dimclass="token punctuation">;wclass="token operator">->w2 class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->hidden_dim class="token operator">* pclass="token operator">->dimclass="token punctuation">;wclass="token operator">->w3 class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= n_layers class="token operator">* pclass="token operator">->dim class="token operator">* pclass="token operator">->hidden_dimclass="token punctuation">;wclass="token operator">->rms_final_weight class="token operator">= ptrclass="token punctuation">;ptr class="token operator">+= pclass="token 
operator">->dimclass="token punctuation">;ptr class="token operator">+= pclass="token operator">->seq_len class="token operator">* head_size class="token operator">/class="token number">2class="token punctuation">; class="token comment">// skip what used to be freq_cis_real (for RoPE)ptr class="token operator">+= pclass="token operator">->seq_len class="token operator">* head_size class="token operator">/class="token number">2class="token punctuation">; class="token comment">// skip what used to be freq_cis_imag (for RoPE)wclass="token operator">->wcls class="token operator">= shared_weights class="token operator">? wclass="token operator">->token_embedding_table class="token operator">: ptrclass="token punctuation">;
class="token punctuation">}
code>
我自己感觉这个仓库最经典的一段代码就是这里:并没有真正加载权重,只是拿到了权重在内存映射中的地址,然后把地址映射给结构体中的成员变量。等到真正推理计算的时候,用到哪一段权重,哪一段权重才会被加载到内存中参与计算。
中间变量初始化:
<code class="prism language-cpp">class="token keyword">void class="token function">malloc_run_stateclass="token punctuation">(RunState class="token operator">*sclass="token punctuation">, Config class="token operator">*pclass="token punctuation">) class="token punctuation">{class="token comment">// we calloc instead of malloc to keep valgrind happyclass="token keyword">int kv_dim class="token operator">= class="token punctuation">(pclass="token operator">->dim class="token operator">* pclass="token operator">->n_kv_headsclass="token punctuation">) class="token operator">/ pclass="token operator">->n_headsclass="token punctuation">;sclass="token operator">->x class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->xb class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->xb2 class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->hb class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->hidden_dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->hb2 
class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->hidden_dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->q class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->key_cache class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->n_layers class="token operator">* pclass="token operator">->seq_len class="token operator">* kv_dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->value_cache class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->n_layers class="token operator">* pclass="token operator">->seq_len class="token operator">* kv_dimclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->att class="token operator">= class="token function">callocclass="token punctuation">(pclass="token operator">->n_heads class="token operator">* pclass="token operator">->seq_lenclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;sclass="token operator">->logits class="token operator">= class="token function">callocclass="token 
punctuation">(pclass="token operator">->vocab_sizeclass="token punctuation">, class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;class="token comment">// ensure all mallocs went fineclass="token keyword">if class="token punctuation">(class="token operator">!sclass="token operator">->x class="token operator">|| class="token operator">!sclass="token operator">->xb class="token operator">|| class="token operator">!sclass="token operator">->xb2 class="token operator">|| class="token operator">!sclass="token operator">->hb class="token operator">|| class="token operator">!sclass="token operator">->hb2 class="token operator">|| class="token operator">!sclass="token operator">->q class="token operator">||class="token operator">!sclass="token operator">->key_cache class="token operator">|| class="token operator">!sclass="token operator">->value_cache class="token operator">|| class="token operator">!sclass="token operator">->att class="token operator">|| class="token operator">!sclass="token operator">->logitsclass="token punctuation">) class="token punctuation">{class="token function">fprintfclass="token punctuation">(class="token constant">stderrclass="token punctuation">, class="token string">"malloc failed!\n"class="token punctuation">)class="token punctuation">;class="token function">exitclass="token punctuation">(EXIT_FAILUREclass="token punctuation">)class="token punctuation">;class="token punctuation">}
class="token punctuation">}
code>
如果不太理解权重初始化和中间变量初始化时为什么要申请那么大的空间,可以自己手动地将网络的数据流从头到尾推一遍。
如何构建 tokenizer 和 sampler
对于这两个模块的构建我们不多介绍,感兴趣的可以自己去看看源码。
如何进行推理
这部分是我最感兴趣的地方。
<code class="prism language-c"> class="token comment">// forward all the layersclass="token keyword">for class="token punctuation">(class="token keyword">unsigned class="token keyword">long class="token keyword">long l class="token operator">= class="token number">0class="token punctuation">; l class="token operator">< pclass="token operator">->n_layersclass="token punctuation">; lclass="token operator">++class="token punctuation">) class="token punctuation">{class="token comment">// attention rmsnormclass="token function">rmsnormclass="token punctuation">(sclass="token operator">->xbclass="token punctuation">, xclass="token punctuation">, wclass="token operator">->rms_att_weight class="token operator">+ l class="token operator">* dimclass="token punctuation">, dimclass="token punctuation">)class="token punctuation">;class="token comment">// key and value point to the kv cacheclass="token keyword">int loff class="token operator">= l class="token operator">* pclass="token operator">->seq_len class="token operator">* kv_dimclass="token punctuation">; class="token comment">// kv cache layer offset for conveniencesclass="token operator">->k class="token operator">= sclass="token operator">->key_cache class="token operator">+ loff class="token operator">+ pos class="token operator">* kv_dimclass="token punctuation">;sclass="token operator">->v class="token operator">= sclass="token operator">->value_cache class="token operator">+ loff class="token operator">+ pos class="token operator">* kv_dimclass="token punctuation">;class="token comment">// qkv matmuls for this positionclass="token function">matmulclass="token punctuation">(sclass="token operator">->qclass="token punctuation">, sclass="token operator">->xbclass="token punctuation">, wclass="token operator">->wq class="token operator">+ l class="token operator">* dim class="token operator">* dimclass="token punctuation">, dimclass="token punctuation">, dimclass="token punctuation">)class="token 
punctuation">;class="token function">matmulclass="token punctuation">(sclass="token operator">->kclass="token punctuation">, sclass="token operator">->xbclass="token punctuation">, wclass="token operator">->wk class="token operator">+ l class="token operator">* dim class="token operator">* kv_dimclass="token punctuation">, dimclass="token punctuation">, kv_dimclass="token punctuation">)class="token punctuation">;class="token function">matmulclass="token punctuation">(sclass="token operator">->vclass="token punctuation">, sclass="token operator">->xbclass="token punctuation">, wclass="token operator">->wv class="token operator">+ l class="token operator">* dim class="token operator">* kv_dimclass="token punctuation">, dimclass="token punctuation">, kv_dimclass="token punctuation">)class="token punctuation">;class="token comment">// RoPE relative positional encoding: complex-valued rotate q and k in eachclass="token comment">// headclass="token keyword">for class="token punctuation">(class="token keyword">int i class="token operator">= class="token number">0class="token punctuation">; i class="token operator">< dimclass="token punctuation">; i class="token operator">+= class="token number">2class="token punctuation">) class="token punctuation">{class="token keyword">int head_dim class="token operator">= i class="token operator">% head_sizeclass="token punctuation">;class="token keyword">float freq class="token operator">= class="token number">1.0f class="token operator">/ class="token function">powfclass="token punctuation">(class="token number">10000.0fclass="token punctuation">, head_dim class="token operator">/ class="token punctuation">(class="token keyword">floatclass="token punctuation">)head_sizeclass="token punctuation">)class="token punctuation">;class="token keyword">float val class="token operator">= pos class="token operator">* freqclass="token punctuation">;class="token keyword">float fcr class="token operator">= class="token function">cosfclass="token 
punctuation">(valclass="token punctuation">)class="token punctuation">;class="token keyword">float fci class="token operator">= class="token function">sinfclass="token punctuation">(valclass="token punctuation">)class="token punctuation">;class="token keyword">int rotn class="token operator">= i class="token operator">< kv_dim class="token operator">? class="token number">2 class="token operator">: class="token number">1class="token punctuation">; class="token comment">// how many vectors? 2 = q & k, 1 = q onlyclass="token keyword">for class="token punctuation">(class="token keyword">int v class="token operator">= class="token number">0class="token punctuation">; v class="token operator">< rotnclass="token punctuation">; vclass="token operator">++class="token punctuation">) class="token punctuation">{class="token keyword">float class="token operator">*vec class="token operator">=v class="token operator">== class="token number">0 class="token operator">? sclass="token operator">->q class="token operator">: sclass="token operator">->kclass="token punctuation">; class="token comment">// the vector to rotate (query or key)class="token keyword">float v0 class="token operator">= vecclass="token punctuation">[iclass="token punctuation">]class="token punctuation">;class="token keyword">float v1 class="token operator">= vecclass="token punctuation">[i class="token operator">+ class="token number">1class="token punctuation">]class="token punctuation">;vecclass="token punctuation">[iclass="token punctuation">] class="token operator">= v0 class="token operator">* fcr class="token operator">- v1 class="token operator">* fciclass="token punctuation">;vecclass="token punctuation">[i class="token operator">+ class="token number">1class="token punctuation">] class="token operator">= v0 class="token operator">* fci class="token operator">+ v1 class="token operator">* fcrclass="token punctuation">;class="token punctuation">}class="token punctuation">}class="token comment">// 
multihead attention. iterate over all headsclass="token keyword">int hclass="token punctuation">;
class="token macro property">class="token directive-hash">#class="token directive keyword">pragma class="token expression">omp parallel class="token keyword">for class="token function">privateclass="token punctuation">(hclass="token punctuation">)class="token keyword">for class="token punctuation">(h class="token operator">= class="token number">0class="token punctuation">; h class="token operator">< pclass="token operator">->n_headsclass="token punctuation">; hclass="token operator">++class="token punctuation">) class="token punctuation">{class="token comment">// get the query vector for this headclass="token keyword">float class="token operator">*q class="token operator">= sclass="token operator">->q class="token operator">+ h class="token operator">* head_sizeclass="token punctuation">;class="token comment">// attention scores for this headclass="token keyword">float class="token operator">*att class="token operator">= sclass="token operator">->att class="token operator">+ h class="token operator">* pclass="token operator">->seq_lenclass="token punctuation">;class="token comment">// iterate over all timesteps, including the current oneclass="token keyword">for class="token punctuation">(class="token keyword">int t class="token operator">= class="token number">0class="token punctuation">; t class="token operator"><= posclass="token punctuation">; tclass="token operator">++class="token punctuation">) class="token punctuation">{class="token comment">// get the key vector for this head and at this timestepclass="token keyword">float class="token operator">*k class="token operator">= sclass="token operator">->key_cache class="token operator">+ loff class="token operator">+ t class="token operator">* kv_dim class="token operator">+ class="token punctuation">(h class="token operator">/ kv_mulclass="token punctuation">) class="token operator">* head_sizeclass="token punctuation">;class="token comment">// calculate the attention score as the dot product of q and 
kclass="token keyword">float score class="token operator">= class="token number">0.0fclass="token punctuation">;class="token keyword">for class="token punctuation">(class="token keyword">int i class="token operator">= class="token number">0class="token punctuation">; i class="token operator">< head_sizeclass="token punctuation">; iclass="token operator">++class="token punctuation">) class="token punctuation">{score class="token operator">+= qclass="token punctuation">[iclass="token punctuation">] class="token operator">* kclass="token punctuation">[iclass="token punctuation">]class="token punctuation">;class="token punctuation">}score class="token operator">/= class="token function">sqrtfclass="token punctuation">(head_sizeclass="token punctuation">)class="token punctuation">;class="token comment">// save the score to the attention bufferattclass="token punctuation">[tclass="token punctuation">] class="token operator">= scoreclass="token punctuation">;class="token punctuation">}class="token comment">// softmax the scores to get attention weights, from 0..pos inclusivelyclass="token function">softmaxclass="token punctuation">(attclass="token punctuation">, pos class="token operator">+ class="token number">1class="token punctuation">)class="token punctuation">;class="token comment">// weighted sum of the values, store back into xbclass="token keyword">float class="token operator">*xb class="token operator">= sclass="token operator">->xb class="token operator">+ h class="token operator">* head_sizeclass="token punctuation">;class="token function">memsetclass="token punctuation">(xbclass="token punctuation">, class="token number">0class="token punctuation">, head_size class="token operator">* class="token keyword">sizeofclass="token punctuation">(class="token keyword">floatclass="token punctuation">)class="token punctuation">)class="token punctuation">;class="token keyword">for class="token punctuation">(class="token keyword">int t class="token operator">= class="token 
number">0class="token punctuation">; t class="token operator"><= posclass="token punctuation">; tclass="token operator">++class="token punctuation">) class="token punctuation">{class="token comment">// get the value vector for this head and at this timestepclass="token keyword">float class="token operator">*v class="token operator">=sclass="token operator">->value_cache class="token operator">+ loff class="token operator">+ t class="token operator">* kv_dim class="token operator">+ class="token punctuation">(h class="token operator">/ kv_mulclass="token punctuation">) class="token operator">* head_sizeclass="token punctuation">;class="token comment">// get the attention weight for this timestepclass="token keyword">float a class="token operator">= attclass="token punctuation">[tclass="token punctuation">]class="token punctuation">;class="token comment">// accumulate the weighted value into xbclass="token keyword">for class="token punctuation">(class="token keyword">int i class="token operator">= class="token number">0class="token punctuation">; i class="token operator">< head_sizeclass="token punctuation">; iclass="token operator">++class="token punctuation">) class="token punctuation">{xbclass="token punctuation">[iclass="token punctuation">] class="token operator">+= a class="token operator">* vclass="token punctuation">[iclass="token punctuation">]class="token punctuation">;class="token punctuation">}class="token punctuation">}class="token punctuation">}class="token comment">// final matmul to get the output of the attentionclass="token function">matmulclass="token punctuation">(sclass="token operator">->xb2class="token punctuation">, sclass="token operator">->xbclass="token punctuation">, wclass="token operator">->wo class="token operator">+ l class="token operator">* dim class="token operator">* dimclass="token punctuation">, dimclass="token punctuation">, dimclass="token punctuation">)class="token punctuation">;class="token comment">// residual connection 
back into xclass="token keyword">for class="token punctuation">(class="token keyword">int i class="token operator">= class="token number">0class="token punctuation">; i class="token operator">< dimclass="token punctuation">; iclass="token operator">++class="token punctuation">) class="token punctuation">{xclass="token punctuation">[iclass="token punctuation">] class="token operator">+= sclass="token operator">->xb2class="token punctuation">[iclass="token punctuation">]class="token punctuation">;class="token punctuation">}class="token comment">// ffn rmsnormclass="token function">rmsnormclass="token punctuation">(sclass="token operator">->xbclass="token punctuation">, xclass="token punctuation">, wclass="token operator">->rms_ffn_weight class="token operator">+ l class="token operator">* dimclass="token punctuation">, dimclass="token punctuation">)class="token punctuation">;class="token comment">// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))class="token comment">// first calculate self.w1(x) and self.w3(x)class="token function">matmulclass="token punctuation">(sclass="token operator">->hbclass="token punctuation">, sclass="token operator">->xbclass="token punctuation">, wclass="token operator">->w1 class="token operator">+ l class="token operator">* dim class="token operator">* hidden_dimclass="token punctuation">, dimclass="token punctuation">, hidden_dimclass="token punctuation">)class="token punctuation">;class="token function">matmulclass="token punctuation">(sclass="token operator">->hb2class="token punctuation">, sclass="token operator">->xbclass="token punctuation">, wclass="token operator">->w3 class="token operator">+ l class="token operator">* dim class="token operator">* hidden_dimclass="token punctuation">, dimclass="token punctuation">, hidden_dimclass="token punctuation">)class="token punctuation">;class="token comment">// SwiGLU non-linearityclass="token keyword">for class="token punctuation">(class="token keyword">int i 
class="token operator">= class="token number">0class="token punctuation">; i class="token operator">< hidden_dimclass="token punctuation">; iclass="token operator">++class="token punctuation">) class="token punctuation">{class="token keyword">float val class="token operator">= sclass="token operator">->hbclass="token punctuation">[iclass="token punctuation">]class="token punctuation">;class="token comment">// silu(x)=x*σ(x), where σ(x) is the logistic sigmoidval class="token operator">*= class="token punctuation">(class="token number">1.0f class="token operator">/ class="token punctuation">(class="token number">1.0f class="token operator">+ class="token function">expfclass="token punctuation">(class="token operator">-valclass="token punctuation">)class="token punctuation">)class="token punctuation">)class="token punctuation">;class="token comment">// elementwise multiply with w3(x)val class="token operator">*= sclass="token operator">->hb2class="token punctuation">[iclass="token punctuation">]class="token punctuation">;sclass="token operator">->hbclass="token punctuation">[iclass="token punctuation">] class="token operator">= valclass="token punctuation">;class="token punctuation">}class="token comment">// final matmul to get the output of the ffnclass="token function">matmulclass="token punctuation">(sclass="token operator">->xbclass="token punctuation">, sclass="token operator">->hbclass="token punctuation">, wclass="token operator">->w2 class="token operator">+ l class="token operator">* dim class="token operator">* hidden_dimclass="token punctuation">, hidden_dimclass="token punctuation">, dimclass="token punctuation">)class="token punctuation">;class="token comment">// residual connectionclass="token keyword">for class="token punctuation">(class="token keyword">int i class="token operator">= class="token number">0class="token punctuation">; i class="token operator">< dimclass="token punctuation">; iclass="token operator">++class="token punctuation">) 
class="token punctuation">{xclass="token punctuation">[iclass="token punctuation">] class="token operator">+= sclass="token operator">->xbclass="token punctuation">[iclass="token punctuation">]class="token punctuation">;class="token punctuation">}class="token punctuation">}
code>
for 循环所有的 layers 进行推理,其中有三个主要的子函数,分别是:rmsnorm、matmul、softmax,对应着三个算子,其他的算子则是直接在 for 循环内实现的。所有的 layer 都计算一遍后,再加上后处理即可完成一个 token 的推理。
总结
总的来说,这个库还是有很多东西值得我们去学习的,学习一下大神的编码思维和编码方式。