从零训练一个 TinyStories 风格 GPT 小模型
从零训练一个 TinyStories 风格 GPT 小模型

这个项目是在 Apple M4 Mac mini 16GB 上,用 MLX 从随机初始化开始训练一个 TinyStories 风格的小型 GPT 模型。它不是调用 API,也不是微调现成模型,而是把数据准备、tokenizer、模型结构、训练循环、checkpoint 和推理生成完整走了一遍。
项目地址:sergioperezcheco/llm-from-scratch
项目结果
最终训练的是一个 44M 参数量的 GPT 模型:
| 项目 | 结果 |
|---|---|
| 模型架构 | Decoder-only Transformer |
| 训练框架 | MLX |
| 数据集 | TinyStories |
| 训练步数 | 10,000 |
| 参数量 | 44,065,280 |
| 最终 loss | 1.1612 |
| 最低 loss | 1.0964 |
| 训练耗时 | 约 9 小时 15 分钟 |
| 训练设备 | Apple M4 Mac mini 16GB |
训练好的模型可以生成英文儿童短故事风格的文本。它不是通用聊天模型,但已经能证明从数据到推理的完整链路是通的。
为什么选择 TinyStories
TinyStories 的优点是语料简单、结构稳定、目标明确,很适合从零训练小模型。对 16GB 统一内存的 Mac mini 来说,直接训练通用语料模型不现实,而 TinyStories 可以把重点放在训练流程本身:
- 数据能在本地准备和编码;
- tokenizer 词表规模可控;
- 模型参数量可以压到几十 M;
- 生成结果容易人工判断是否学到了故事结构。
这类项目的价值不在于得到一个强模型,而在于理解 LLM 训练的每个环节为什么存在。
模型结构
项目里的模型是一个标准 GPT 风格的 Decoder-only Transformer,核心配置如下:
| 参数 | 值 |
|---|---|
| Transformer 层数 | 8 |
| 隐藏维度 | 512 |
| 注意力头数 | 8 |
| FFN 维度 | 2048 |
| 词表大小 | 10,000 |
| 上下文长度 | 512 |
结构上包含 token embedding、position embedding、因果自注意力、RMSNorm、SwiGLU FFN 和最终的 LM Head。整体思路接近 nanoGPT,但用 MLX 适配 Apple Silicon。
训练流程
训练脚本做了几件比较关键的事情:
- 使用
numpy.memmap读取编码后的二进制 token 数据,避免一次性吃满内存。 - 用梯度累积把小 batch 模拟成更大的有效 batch。
- 学习率使用 warmup + cosine decay。
- 每隔固定步数保存 safetensors checkpoint 和 JSON 元数据。
- 训练过程中记录 loss、速度、内存占用和 ETA。
- 支持中断后保存当前 checkpoint。
对 16GB 机器来说,内存管理比单纯堆模型更重要。项目里最终选择 MEDIUM_CONFIG,比早期设想的 100M 参数默认配置更稳,也更适合长时间跑完。
推理效果
训练完成后,可以用 generate.py 加载 checkpoints/final.safetensors 生成文本。例如:
python generate.py \
--checkpoint checkpoints/final.safetensors \
--prompt "Lily found a tiny door under the old tree. " \
--max-tokens 120 \
--temperature 0.7 \
--top-k 50 \
--n-samples 1这个模型对英文故事开头最敏感,尤其适合短句、童话式 prompt。中文、问答、代码和通用聊天都不是它的训练目标。
踩坑和取舍
这个项目里最实际的取舍是:不要一开始就追参数量。
在 Mac mini M4 16GB 上,统一内存虽然让 CPU/GPU 数据交换更方便,但总内存就是 16GB。模型参数、中间激活、优化器状态、系统和其他应用都会抢同一块内存。直接上更大的模型,训练很容易被内存压力打断。
另一个需要注意的点是 checkpoint 恢复。当前项目可以恢复模型权重和 step 元数据,但 optimizer state 没有完整恢复,所以严格意义上不是完全无损续训。如果是正式训练,这里还可以继续优化。
总结
这个项目最有价值的地方,是把“从零训练 LLM”拆成了能在个人机器上跑通的工程闭环:
- 数据准备:TinyStories 下载、切分、编码;
- 分词器:训练 10,000 词表 BPE tokenizer;
- 模型:手写 GPT Transformer;
- 训练:MLX、梯度累积、学习率调度、checkpoint;
- 推理:加载 safetensors 并采样生成。
如果目标是理解 LLM 的底层训练流程,这比直接调 API 更有意义;如果目标是得到可用的通用助手,那还是应该使用现成的大模型或微调路线。
