谷歌JAX深度學習從零開始學

谷歌JAX深度學習從零開始學

作者: 王曉華
出版社: 清華大學
出版在: 2022-06-01
ISBN-13: 9787302604365
ISBN-10: 7302604363





內容描述


JAX是一個用於高性能數值計算的Python庫,專門為深度學習領域的高性能計算而設計。本書詳解JAX框架深度學習的相關知識,配套示例源碼、PPT課件、數據集和開發環境。
本書共分為13章,內容包括JAX從零開始,一學就會的線性回歸、多層感知機與自動微分器,深度學習的理論基礎,XLA與JAX一般特性,JAX的高級特性,JAX的一些細節,JAX中的捲積,JAX與TensorFlow的比較與交互,遵循JAX函數基本規則下的自定義函數,JAX中的高級包。最後給出3個實戰案例:使用ResNet完成CIFAR100數據集分類,有趣的詞嵌入,生成對抗網絡(GAN)。
本書適合JAX框架初學者、深度學習初學者以及深度學習從業人員,也適合作為高等院校和培訓機構人工智能相關專業的師生教學參考書。


目錄大綱


目    錄
第 1 章  JAX從零開始 1
1.1  JAX來了 1
1.1.1  JAX是什麽 1
1.1.2  為什麽是JAX 2
1.2  JAX的安裝與使用 3
1.2.1  Windows Subsystem for Linux的安裝 3
1.2.2  JAX的安裝和驗證 7
1.2.3  PyCharm的下載與安裝 8
1.2.4  使用PyCharm和JAX 9
1.2.5  JAX的Python代碼小練習:計算SeLU函數 11
1.3  JAX實戰—MNIST手寫體的識別 12
1.3.1  第一步:準備數據集 12
1.3.2  第二步:模型的設計 13
1.3.3  第三步:模型的訓練 13
1.4  本章小結 15
第2章  一學就會的線性回歸、多層感知機與自動微分器 16
2.1  多層感知機 16
2.1.1  全連接層—多層感知機的隱藏層 16
2.1.2  使用JAX實現一個全連接層 17
2.1.3  更多功能的全連接函數 19
2.2  JAX實戰—鳶尾花分類 22
2.2.1  鳶尾花數據準備與分析 23
2.2.2  模型分析—採用線性回歸實戰鳶尾花分類 24
2.2.3  基於JAX的線性回歸模型的編寫 25
2.2.4  多層感知機與神經網絡 27
2.2.5  基於JAX的激活函數、softmax函數與交叉熵函數 29
2.2.6  基於多層感知機的鳶尾花分類實戰 31
2.3  自動微分器 35
2.3.1  什麽是微分器 36
2.3.2  JAX中的自動微分 37
2.4  本章小結 38
第3章  深度學習的理論基礎 39
3.1  BP神經網絡簡介 39
3.2  BP神經網絡兩個基礎算法詳解 42
3.2.1  最小二乘法詳解 43
3.2.2  道士下山的故事—梯度下降算法 44
3.2.3  最小二乘法的梯度下降算法以及JAX實現 46
3.3  反饋神經網絡反向傳播算法介紹 52
3.3.1  深度學習基礎 52
3.3.2  鏈式求導法則 53
3.3.3  反饋神經網絡原理與公式推導 54
3.3.4  反饋神經網絡原理的激活函數 59
3.3.5  反饋神經網絡原理的Python實現 60
3.4  本章小結 64
第4章  XLA與JAX一般特性 65
4.1  JAX與XLA 65
4.1.1  XLA如何運行 65
4.1.2  XLA如何工作 67
4.2  JAX一般特性 67
4.2.1  利用JIT加快程序運行 67
4.2.2  自動微分器—grad函數 68
4.2.3  自動向量化映射—vmap函數 70
4.3  本章小結 71
第5章  JAX的高級特性 72
5.1  JAX與NumPy 72
5.1.1  像NumPy一樣運行的JAX 72
5.1.2  JAX的底層實現lax 74
5.1.3  並行化的JIT機制與不適合使用JIT的情景 75
5.1.4  JIT的參數詳解 77
5.2  JAX程序的編寫規範要求 78
5.2.1  JAX函數必須要為純函數 79
5.2.2  JAX中數組的規範操作 80
5.2.3  JIT中的控制分支 83
5.2.4  JAX中的if、while、for、scan函數 85
5.3  本章小結 89
第6章  JAX的一些細節 90
6.1  JAX中的數值計算 90
6.1.1  JAX中的grad函數使用細節 90
6.1.2  不要編寫帶有副作用的代碼—JAX與NumPy的差異 93
6.1.3  一個簡單的線性回歸方程擬合 94
6.2  JAX中的性能提高 98
6.2.1  JIT的轉換過程 98
6.2.2  JIT無法對非確定參數追蹤 100
6.2.3  理解JAX中的預編譯與緩存 102
6.3  JAX中的函數自動打包器—vmap 102
6.3.1  剝洋蔥—對數據的手工打包 102
6.3.2  剝甘藍—JAX中的自動向量化函數vmap 104
6.3.3  JAX中高階導數的處理 105
6.4  JAX中的結構體保存方法Pytrees 106
6.4.1  Pytrees是什麽 106
6.4.2  常見的pytree函數 107
6.4.3  深度學習模型參數的控制(線性模型) 108
6.4.4  深度學習模型參數的控制(非線性模型) 113
6.4.5  自定義的Pytree節點 113
6.4.6  JAX數值計算的運行機制 115
6.5  本章小結 117
第7章  JAX中的捲積 118
7.1  什麽是捲積 118
7.1.1  捲積運算 119
7.1.2  JAX中的一維捲積與多維捲積的計算 120
7.1.3  JAX.lax中的一般捲積的計算與表示 122
7.2  JAX實戰—基於VGG架構的MNIST數據集分類 124
7.2.1  深度學習Visual Geometry Group(VGG)架構 124
7.2.2  VGG中使用的組件介紹與實現 126
7.2.3  基於VGG6的MNIST數據集分類實戰 129
7.3  本章小結 133
第8章  JAX與TensorFlow的比較與交互 134
8.1  基於TensorFlow的MNIST分類 134
8.2  TensorFlow與JAX的交互 137
8.2.1  基於JAX的TensorFlow Datasets數據集分類實戰 137
8.2.2  TensorFlow Datasets數據集庫簡介 141
8.3  本章小結 145
第9章  遵循JAX函數基本規則下的自定義函數 146
9.1  JAX函數的基本規則 146
9.1.1  使用已有的原語 146
9.1.2  自定義的JVP以及反向VJP 147
9.1.3  進階jax.custom_jvp和jax.custom_vjp函數用法 150
9.2  Jaxpr解釋器的使用 153
9.2.1  Jaxpr tracer 153
9.2.2  自定義的可以被Jaxpr跟蹤的函數 155
9.3  JAX維度名稱的使用 157
9.3.1  JAX的維度名稱 157
9.3.2  自定義JAX中的向量Tensor 158
9.4  本章小結 159
第10章  JAX中的高級包 160
10.1  JAX中的包 160
10.1.1  jax.numpy的使用 161
10.1.2  jax.nn的使用 162
10.2  jax.experimental包和jax.example_libraries的使用 163
10.2.1  jax.experimental.sparse的使用 163
10.2.2  jax.experimental.optimizers模塊的使用 166
10.2.3  jax.experimental.stax的使用 168
10.3  本章小結 168
第11章  JAX實戰——使用ResNet完成CIFAR100數據集分類 169
11.1  ResNet基礎原理與程序設計基礎 169
11.1.1  ResNet誕生的背景 170
11.1.2  使用JAX中實現的部件—不要重復造輪子 173
11.1.3  一些stax模塊中特有的類 175
11.2  ResNet實戰—CIFAR100數據集分類 176
11.2.1  CIFAR100數據集簡介 176
11.2.2  ResNet殘差模塊的實現 179
11.2.3  ResNet網絡的實現 181
11.2.4  使用ResNet對CIFAR100數據集進行分類 182
11.3  本章小結 184
第12章  JAX實戰—有趣的詞嵌入 185
12.1  文本數據處理 185
12.1.1  數據集和數據清洗 185
12.1.2  停用詞的使用 188
12.1.3  詞向量訓練模型word2vec的使用 190
12.1.4  文本主題的提取:基於TF-IDF 193
12.1.5  文本主題的提取:基於TextRank 197
12.2  更多的詞嵌入方法—FastText和預訓練詞向量 200
12.2.1  FastText的原理與基礎算法 201
12.2.2  FastText訓練以及與JAX的協同使用 202
12.2.3  使用其他預訓練參數嵌入矩陣(中文) 204
12.3  針對文本的捲積神經網絡模型—字符捲積 205
12.3.1  字符(非單詞)文本的處理 206
12.3.2  捲積神經網絡文本分類模型的實現—conv1d(一維捲積) 213
12.4  針對文本的捲積神經網絡模型—詞捲積 216
12.4.1  單詞的文本處理 216
12.4.2  捲積神經網絡文本分類模型的實現 218
12.5  使用捲積對文本分類的補充內容 219
12.5.1  中文的文本處理 219
12.5.2  其他細節 222
12.6  本章小結 222
第13章  JAX實戰—生成對抗網絡(GAN) 223
13.1  GAN的工作原理詳解 223
13.1.1  生成器與判別器共同構成了一個GAN 224
13.1.2  GAN是怎麽工作的 225
13.2  GAN的數學原理詳解 225
13.2.1  GAN的損失函數 226
13.2.2  生成器的產生分佈的數學原理—相對熵簡介 226
13.3  JAX實戰—GAN網絡 227
13.3.1  生成對抗網絡GAN的實現 228
13.3.2  GAN的應用前景 232
13.4  本章小結 235
附錄  Windows 11安裝GPU版本的JAX 236




相關書籍

必學!Python 資料科學‧機器學習最強套件 - NumPy、Pandas、Matplotlib、OpenCV、scikit-learn、tf.Keras

作者 株式会社アイデミー 石川 聡彦 劉金讓 譯;施威銘研究室 監修

2022-06-01

Predictive Control for Linear and Hybrid Systems

作者 Francesco Borrelli Alberto Bemporad Manfred Morari

2022-06-01

Applied Computational Thinking with Python: Design algorithmic solutions for complex and challenging real-world problems

作者 Jesus Sofía de Martinez Dayrene

2022-06-01