{"id":1279,"date":"2023-08-13T05:48:31","date_gmt":"2023-08-13T05:48:31","guid":{"rendered":"https:\/\/www.gptmain.news\/?p=1279"},"modified":"2023-08-13T05:48:31","modified_gmt":"2023-08-13T05:48:31","slug":"%d1%81%d0%be%d0%b7%d0%b4%d0%b0%d0%b9%d1%82%d0%b5-transformer-%d0%b2-jax-%d1%81-%d0%bd%d1%83%d0%bb%d1%8f-%d0%ba%d0%b0%d0%ba-%d0%bd%d0%b0%d0%bf%d0%b8%d1%81%d0%b0%d1%82%d1%8c-%d0%b8-%d0%be%d0%b1%d1%83","status":"publish","type":"post","link":"https:\/\/gptmain.news\/?p=1279","title":{"rendered":"\u0421\u043e\u0437\u0434\u0430\u0439\u0442\u0435 Transformer \u0432 JAX \u0441 \u043d\u0443\u043b\u044f: \u043a\u0430\u043a \u043d\u0430\u043f\u0438\u0441\u0430\u0442\u044c \u0438 \u043e\u0431\u0443\u0447\u0438\u0442\u044c \u0441\u0432\u043e\u0438 \u0441\u043e\u0431\u0441\u0442\u0432\u0435\u043d\u043d\u044b\u0435 \u043c\u043e\u0434\u0435\u043b\u0438\n | GPTMain News"},"content":{"rendered":"<div id=\"\">\n<p>\u0412 \u044d\u0442\u043e\u043c \u0440\u0443\u043a\u043e\u0432\u043e\u0434\u0441\u0442\u0432\u0435 \u043c\u044b \u0440\u0430\u0441\u0441\u043c\u043e\u0442\u0440\u0438\u043c, \u043a\u0430\u043a \u0440\u0430\u0437\u0440\u0430\u0431\u043e\u0442\u0430\u0442\u044c \u043d\u0435\u0439\u0440\u043e\u043d\u043d\u0443\u044e \u0441\u0435\u0442\u044c (NN) \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e JAX.  \u0410 \u043a\u0430\u043a\u0443\u044e \u043c\u043e\u0434\u0435\u043b\u044c \u0432\u044b\u0431\u0440\u0430\u0442\u044c \u0434\u043b\u044f \u044d\u0442\u043e\u0433\u043e, \u043a\u0430\u043a \u043d\u0435 \u0422\u0440\u0430\u043d\u0441\u0444\u043e\u0440\u043c\u0435\u0440?  
\u041f\u043e \u043c\u0435\u0440\u0435 \u0440\u043e\u0441\u0442\u0430 \u043f\u043e\u043f\u0443\u043b\u044f\u0440\u043d\u043e\u0441\u0442\u0438 JAX \u0432\u0441\u0435 \u0431\u043e\u043b\u044c\u0448\u0435 \u0438 \u0431\u043e\u043b\u044c\u0448\u0435 \u043a\u043e\u043c\u0430\u043d\u0434 \u0440\u0430\u0437\u0440\u0430\u0431\u043e\u0442\u0447\u0438\u043a\u043e\u0432 \u043d\u0430\u0447\u0438\u043d\u0430\u044e\u0442 \u044d\u043a\u0441\u043f\u0435\u0440\u0438\u043c\u0435\u043d\u0442\u0438\u0440\u043e\u0432\u0430\u0442\u044c \u0441 \u043d\u0438\u043c \u0438 \u0432\u043a\u043b\u044e\u0447\u0430\u0442\u044c \u0435\u0433\u043e \u0432 \u0441\u0432\u043e\u0438 \u043f\u0440\u043e\u0435\u043a\u0442\u044b.  \u041d\u0435\u0441\u043c\u043e\u0442\u0440\u044f \u043d\u0430 \u0442\u043e, \u0447\u0442\u043e \u0435\u043c\u0443 \u043d\u0435 \u0445\u0432\u0430\u0442\u0430\u0435\u0442 \u0437\u0440\u0435\u043b\u043e\u0441\u0442\u0438 Tensorflow \u0438\u043b\u0438 Pytorch, \u043e\u043d \u043f\u0440\u0435\u0434\u043e\u0441\u0442\u0430\u0432\u043b\u044f\u0435\u0442 \u043d\u0435\u043a\u043e\u0442\u043e\u0440\u044b\u0435 \u043e\u0442\u043b\u0438\u0447\u043d\u044b\u0435 \u0444\u0443\u043d\u043a\u0446\u0438\u0438 \u0434\u043b\u044f \u0441\u043e\u0437\u0434\u0430\u043d\u0438\u044f \u0438 \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f \u043c\u043e\u0434\u0435\u043b\u0435\u0439 \u0433\u043b\u0443\u0431\u043e\u043a\u043e\u0433\u043e \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f.<\/p>\n<p>\u0427\u0442\u043e\u0431\u044b \u043f\u043e\u043b\u0443\u0447\u0438\u0442\u044c \u043f\u043e\u043b\u043d\u043e\u0435 \u043f\u0440\u0435\u0434\u0441\u0442\u0430\u0432\u043b\u0435\u043d\u0438\u0435 \u043e\u0431 \u043e\u0441\u043d\u043e\u0432\u0430\u0445 JAX, \u043f\u0440\u043e\u0447\u0442\u0438\u0442\u0435 \u043c\u043e\u044e \u043f\u0440\u0435\u0434\u044b\u0434\u0443\u0449\u0443\u044e \u0441\u0442\u0430\u0442\u044c\u044e, \u0435\u0441\u043b\u0438 \u0432\u044b \u0435\u0449\u0435 \u044d\u0442\u043e\u0433\u043e 
\u043d\u0435 \u0441\u0434\u0435\u043b\u0430\u043b\u0438.  \u0422\u0430\u043a\u0436\u0435 \u0432\u044b \u043c\u043e\u0436\u0435\u0442\u0435 \u043d\u0430\u0439\u0442\u0438 \u043f\u043e\u043b\u043d\u044b\u0439 \u043a\u043e\u0434 \u0432 \u043d\u0430\u0448\u0435\u043c \u0440\u0435\u043f\u043e\u0437\u0438\u0442\u043e\u0440\u0438\u0438 Github.<\/p>\n<p>\u041e\u0434\u043d\u0430 \u0438\u0437 \u0440\u0430\u0441\u043f\u0440\u043e\u0441\u0442\u0440\u0430\u043d\u0435\u043d\u043d\u044b\u0445 \u043f\u0440\u043e\u0431\u043b\u0435\u043c, \u0441 \u043a\u043e\u0442\u043e\u0440\u043e\u0439 \u0441\u0442\u0430\u043b\u043a\u0438\u0432\u0430\u044e\u0442\u0441\u044f \u043b\u044e\u0434\u0438, \u043d\u0430\u0447\u0438\u043d\u0430\u044e\u0449\u0438\u0435 \u0440\u0430\u0431\u043e\u0442\u0430\u0442\u044c \u0441 JAX, \u2014 \u044d\u0442\u043e \u0432\u044b\u0431\u043e\u0440 \u0444\u0440\u0435\u0439\u043c\u0432\u043e\u0440\u043a\u0430.  \u041f\u043e\u0445\u043e\u0436\u0435, \u0447\u0442\u043e \u043b\u044e\u0434\u0438 \u0432 Deepmind \u043e\u0447\u0435\u043d\u044c \u0437\u0430\u043d\u044f\u0442\u044b \u0438 \u0443\u0436\u0435 \u0432\u044b\u043f\u0443\u0441\u0442\u0438\u043b\u0438 \u043c\u043d\u043e\u0436\u0435\u0441\u0442\u0432\u043e \u0444\u0440\u0435\u0439\u043c\u0432\u043e\u0440\u043a\u043e\u0432 \u043f\u043e\u0432\u0435\u0440\u0445 JAX.  
\u0412\u043e\u0442 \u0441\u043f\u0438\u0441\u043e\u043a \u0441\u0430\u043c\u044b\u0445 \u0438\u0437\u0432\u0435\u0441\u0442\u043d\u044b\u0445 \u0438\u0437 \u043d\u0438\u0445:<\/p>\n<ul>\n<li>\n<p>Haiku: Haiku \u2014 \u044d\u0442\u043e \u0444\u0440\u0435\u0439\u043c\u0432\u043e\u0440\u043a \u0434\u043b\u044f \u0433\u043b\u0443\u0431\u043e\u043a\u043e\u0433\u043e \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f, \u043a\u043e\u0442\u043e\u0440\u044b\u0439 \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u0443\u0435\u0442\u0441\u044f \u043c\u043d\u043e\u0433\u0438\u043c\u0438 \u0432\u043d\u0443\u0442\u0440\u0435\u043d\u043d\u0438\u043c\u0438 \u043a\u043e\u043c\u0430\u043d\u0434\u0430\u043c\u0438 Google \u0438 Deepmind.  \u041e\u043d \u043f\u0440\u0435\u0434\u043e\u0441\u0442\u0430\u0432\u043b\u044f\u0435\u0442 \u043d\u0435\u0441\u043a\u043e\u043b\u044c\u043a\u043e \u043f\u0440\u043e\u0441\u0442\u044b\u0445 \u043a\u043e\u043c\u043f\u043e\u043d\u0443\u0435\u043c\u044b\u0445 \u0430\u0431\u0441\u0442\u0440\u0430\u043a\u0446\u0438\u0439 \u0434\u043b\u044f \u0438\u0441\u0441\u043b\u0435\u0434\u043e\u0432\u0430\u043d\u0438\u0439 \u0432 \u043e\u0431\u043b\u0430\u0441\u0442\u0438 \u043c\u0430\u0448\u0438\u043d\u043d\u043e\u0433\u043e \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f, \u0430 \u0442\u0430\u043a\u0436\u0435 \u0433\u043e\u0442\u043e\u0432\u044b\u0435 \u043a \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u043e\u0432\u0430\u043d\u0438\u044e \u043c\u043e\u0434\u0443\u043b\u0438 \u0438 \u0441\u043b\u043e\u0438.<\/p>\n<\/li>\n<li>\n<p>Optax: Optax \u2014 \u044d\u0442\u043e \u0431\u0438\u0431\u043b\u0438\u043e\u0442\u0435\u043a\u0430 \u043e\u0431\u0440\u0430\u0431\u043e\u0442\u043a\u0438 \u0438 \u043e\u043f\u0442\u0438\u043c\u0438\u0437\u0430\u0446\u0438\u0438 \u0433\u0440\u0430\u0434\u0438\u0435\u043d\u0442\u043e\u0432, \u043a\u043e\u0442\u043e\u0440\u0430\u044f \u0441\u043e\u0434\u0435\u0440\u0436\u0438\u0442 \u0433\u043e\u0442\u043e\u0432\u044b\u0435 
\u043e\u043f\u0442\u0438\u043c\u0438\u0437\u0430\u0442\u043e\u0440\u044b \u0438 \u0441\u0432\u044f\u0437\u0430\u043d\u043d\u044b\u0435 \u043c\u0430\u0442\u0435\u043c\u0430\u0442\u0438\u0447\u0435\u0441\u043a\u0438\u0435 \u043e\u043f\u0435\u0440\u0430\u0446\u0438\u0438.<\/p>\n<\/li>\n<li>\n<p>RLax: RLax \u2014 \u044d\u0442\u043e \u0441\u0438\u0441\u0442\u0435\u043c\u0430 \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f \u0441 \u043f\u043e\u0434\u043a\u0440\u0435\u043f\u043b\u0435\u043d\u0438\u0435\u043c \u0441\u043e \u043c\u043d\u043e\u0436\u0435\u0441\u0442\u0432\u043e\u043c \u043f\u043e\u0434\u043a\u043e\u043c\u043f\u043e\u043d\u0435\u043d\u0442\u043e\u0432 \u0438 \u043e\u043f\u0435\u0440\u0430\u0446\u0438\u0439 RL.<\/p>\n<\/li>\n<li>\n<p>Chex: Chex \u2014 \u044d\u0442\u043e \u0431\u0438\u0431\u043b\u0438\u043e\u0442\u0435\u043a\u0430 \u0443\u0442\u0438\u043b\u0438\u0442 \u0434\u043b\u044f \u0442\u0435\u0441\u0442\u0438\u0440\u043e\u0432\u0430\u043d\u0438\u044f \u0438 \u043e\u0442\u043b\u0430\u0434\u043a\u0438 \u043a\u043e\u0434\u0430 JAX.<\/p>\n<\/li>\n<li>\n<p>Jraph: Jraph \u2014 \u044d\u0442\u043e \u0431\u0438\u0431\u043b\u0438\u043e\u0442\u0435\u043a\u0430 \u0433\u0440\u0430\u0444\u043e\u0432\u044b\u0445 \u043d\u0435\u0439\u0440\u043e\u043d\u043d\u044b\u0445 \u0441\u0435\u0442\u0435\u0439 \u0432 JAX.<\/p>\n<\/li>\n<li>\n<p>Flax: Flax \u2014 \u0435\u0449\u0435 \u043e\u0434\u043d\u0430 \u0431\u0438\u0431\u043b\u0438\u043e\u0442\u0435\u043a\u0430 \u043d\u0435\u0439\u0440\u043e\u043d\u043d\u044b\u0445 \u0441\u0435\u0442\u0435\u0439 \u0441 \u043c\u043d\u043e\u0436\u0435\u0441\u0442\u0432\u043e\u043c \u0433\u043e\u0442\u043e\u0432\u044b\u0445 \u043a \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u043e\u0432\u0430\u043d\u0438\u044e \u043c\u043e\u0434\u0443\u043b\u0435\u0439, \u043e\u043f\u0442\u0438\u043c\u0438\u0437\u0430\u0442\u043e\u0440\u043e\u0432 \u0438 \u0443\u0442\u0438\u043b\u0438\u0442.  
\u042d\u0442\u043e, \u0441\u043a\u043e\u0440\u0435\u0435 \u0432\u0441\u0435\u0433\u043e, \u0441\u0430\u043c\u043e\u0435 \u0431\u043b\u0438\u0437\u043a\u043e\u0435, \u0447\u0442\u043e \u0443 \u043d\u0430\u0441 \u0435\u0441\u0442\u044c \u0432\u043e \u0432\u0441\u0435\u0439 JAX-\u0438\u043d\u0444\u0440\u0430\u0441\u0442\u0440\u0443\u043a\u0442\u0443\u0440\u0435.<\/p>\n<\/li>\n<li>\n<p>Objax: Objax \u2014 \u044d\u0442\u043e \u0442\u0440\u0435\u0442\u044c\u044f \u0431\u0438\u0431\u043b\u0438\u043e\u0442\u0435\u043a\u0430 \u043c\u0430\u0448\u0438\u043d\u043d\u043e\u0433\u043e \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f, \u043e\u0440\u0438\u0435\u043d\u0442\u0438\u0440\u043e\u0432\u0430\u043d\u043d\u0430\u044f \u043d\u0430 \u043e\u0431\u044a\u0435\u043a\u0442\u043d\u043e-\u043e\u0440\u0438\u0435\u043d\u0442\u0438\u0440\u043e\u0432\u0430\u043d\u043d\u043e\u0435 \u043f\u0440\u043e\u0433\u0440\u0430\u043c\u043c\u0438\u0440\u043e\u0432\u0430\u043d\u0438\u0435 \u0438 \u0447\u0438\u0442\u0430\u0431\u0435\u043b\u044c\u043d\u043e\u0441\u0442\u044c \u043a\u043e\u0434\u0430.  \u041e\u043f\u044f\u0442\u044c \u0436\u0435, \u043e\u043d\u0430 \u0441\u043e\u0434\u0435\u0440\u0436\u0438\u0442 \u0441\u0430\u043c\u044b\u0435 \u043f\u043e\u043f\u0443\u043b\u044f\u0440\u043d\u044b\u0435 \u043c\u043e\u0434\u0443\u043b\u0438, \u0444\u0443\u043d\u043a\u0446\u0438\u0438 \u0430\u043a\u0442\u0438\u0432\u0430\u0446\u0438\u0438 \u0438 \u043f\u043e\u0442\u0435\u0440\u044c, \u043e\u043f\u0442\u0438\u043c\u0438\u0437\u0430\u0442\u043e\u0440\u044b, \u0430 \u0442\u0430\u043a\u0436\u0435 \u043d\u0435\u0441\u043a\u043e\u043b\u044c\u043a\u043e <strong>\u043f\u0440\u0435\u0434\u0432\u0430\u0440\u0438\u0442\u0435\u043b\u044c\u043d\u043e \u043e\u0431\u0443\u0447\u0435\u043d\u043d\u044b\u0445 \u043c\u043e\u0434\u0435\u043b\u0435\u0439<\/strong>.<\/p>\n<\/li>\n<li>\n<p>Trax: Trax \u2014 \u044d\u0442\u043e \u0441\u043a\u0432\u043e\u0437\u043d\u0430\u044f \u0431\u0438\u0431\u043b\u0438\u043e\u0442\u0435\u043a\u0430 \u0434\u043b\u044f \u0433\u043b\u0443\u0431\u043e\u043a\u043e\u0433\u043e 
\u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f, \u043e\u0440\u0438\u0435\u043d\u0442\u0438\u0440\u043e\u0432\u0430\u043d\u043d\u0430\u044f \u043d\u0430 \u0442\u0440\u0430\u043d\u0441\u0444\u043e\u0440\u043c\u0435\u0440\u044b.<\/p>\n<\/li>\n<li>\n<p>JAXline: JAXline \u2014 \u044d\u0442\u043e \u0431\u0438\u0431\u043b\u0438\u043e\u0442\u0435\u043a\u0430 \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f \u0441 \u0443\u0447\u0438\u0442\u0435\u043b\u0435\u043c, \u043a\u043e\u0442\u043e\u0440\u0430\u044f \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u0443\u0435\u0442\u0441\u044f \u0434\u043b\u044f <strong>\u0440\u0430\u0441\u043f\u0440\u0435\u0434\u0435\u043b\u0435\u043d\u043d\u043e\u0433\u043e \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f JAX<\/strong> \u0438 \u043e\u0446\u0435\u043d\u043a\u0438.<\/p>\n<\/li>\n<li>\n<p>ACME: ACME \u2014 \u0435\u0449\u0435 \u043e\u0434\u0438\u043d \u0438\u0441\u0441\u043b\u0435\u0434\u043e\u0432\u0430\u0442\u0435\u043b\u044c\u0441\u043a\u0438\u0439 \u0444\u0440\u0435\u0439\u043c\u0432\u043e\u0440\u043a \u0434\u043b\u044f \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f \u0441 \u043f\u043e\u0434\u043a\u0440\u0435\u043f\u043b\u0435\u043d\u0438\u0435\u043c.<\/p>\n<\/li>\n<li>\n<p>JAX-MD: JAX-MD \u2014 \u044d\u0442\u043e \u043d\u0438\u0448\u0435\u0432\u044b\u0439 \u0444\u0440\u0435\u0439\u043c\u0432\u043e\u0440\u043a, \u043a\u043e\u0442\u043e\u0440\u044b\u0439 \u0437\u0430\u043d\u0438\u043c\u0430\u0435\u0442\u0441\u044f \u043c\u043e\u043b\u0435\u043a\u0443\u043b\u044f\u0440\u043d\u043e\u0439 \u0434\u0438\u043d\u0430\u043c\u0438\u043a\u043e\u0439.<\/p>\n<\/li>\n<li>\n<p>JAXChem: JAXChem \u2014 \u0435\u0449\u0435 \u043e\u0434\u043d\u0430 \u043d\u0438\u0448\u0435\u0432\u0430\u044f \u0431\u0438\u0431\u043b\u0438\u043e\u0442\u0435\u043a\u0430, \u0432 \u043a\u043e\u0442\u043e\u0440\u043e\u0439 \u043e\u0441\u043e\u0431\u043e\u0435 \u0432\u043d\u0438\u043c\u0430\u043d\u0438\u0435 \u0443\u0434\u0435\u043b\u044f\u0435\u0442\u0441\u044f 
\u0445\u0438\u043c\u0438\u0447\u0435\u0441\u043a\u043e\u043c\u0443 \u043c\u043e\u0434\u0435\u043b\u0438\u0440\u043e\u0432\u0430\u043d\u0438\u044e.<\/p>\n<\/li>\n<\/ul>\n<p>\u041a\u043e\u043d\u0435\u0447\u043d\u043e, \u0432\u043e\u043f\u0440\u043e\u0441 \u0432 \u0442\u043e\u043c, \u0447\u0442\u043e \u043c\u043d\u0435 \u0432\u044b\u0431\u0440\u0430\u0442\u044c?<\/p>\n<p>\u0427\u0435\u0441\u0442\u043d\u043e \u0433\u043e\u0432\u043e\u0440\u044f, \u044f \u043d\u0435 \u0443\u0432\u0435\u0440\u0435\u043d.<\/p>\n<p>\u041d\u043e \u0435\u0441\u043b\u0438 \u0431\u044b \u044f \u0431\u044b\u043b \u043d\u0430 \u0432\u0430\u0448\u0435\u043c \u043c\u0435\u0441\u0442\u0435 \u0438 \u0445\u043e\u0442\u0435\u043b \u0431\u044b \u0438\u0437\u0443\u0447\u0438\u0442\u044c JAX, \u044f \u0431\u044b \u043d\u0430\u0447\u0430\u043b \u0441 \u0441\u0430\u043c\u044b\u0445 \u043f\u043e\u043f\u0443\u043b\u044f\u0440\u043d\u044b\u0445.  Haiku \u0438 Flax \u0447\u0430\u0441\u0442\u043e \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u0443\u044e\u0442\u0441\u044f \u0432 Google\/Deepmind \u0438 \u0438\u043c\u0435\u044e\u0442 \u0441\u0430\u043c\u043e\u0435 \u0430\u043a\u0442\u0438\u0432\u043d\u043e\u0435 \u0441\u043e\u043e\u0431\u0449\u0435\u0441\u0442\u0432\u043e \u043d\u0430 Github.  \u0412 \u044d\u0442\u043e\u0439 \u0441\u0442\u0430\u0442\u044c\u0435 \u044f \u043d\u0430\u0447\u043d\u0443 \u0441 \u043f\u0435\u0440\u0432\u043e\u0433\u043e \u0438 \u043f\u043e\u0441\u043c\u043e\u0442\u0440\u044e, \u043f\u043e\u043d\u0430\u0434\u043e\u0431\u0438\u0442\u0441\u044f \u043b\u0438 \u043c\u043d\u0435 \u0435\u0449\u0435 \u043e\u0434\u0438\u043d \u0432 \u0431\u0443\u0434\u0443\u0449\u0435\u043c.<\/p>\n<p>\u0418\u0442\u0430\u043a, \u0432\u044b \u0433\u043e\u0442\u043e\u0432\u044b \u043f\u043e\u0441\u0442\u0440\u043e\u0438\u0442\u044c Transformer \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e JAX \u0438 Haiku?  
\u041a\u0441\u0442\u0430\u0442\u0438, \u044f \u043f\u0440\u0435\u0434\u043f\u043e\u043b\u0430\u0433\u0430\u044e, \u0447\u0442\u043e \u0432\u044b \u0445\u043e\u0440\u043e\u0448\u043e \u0440\u0430\u0437\u0431\u0438\u0440\u0430\u0435\u0442\u0435\u0441\u044c \u0432 \u0442\u0440\u0430\u043d\u0441\u0444\u043e\u0440\u043c\u0435\u0440\u0430\u0445.  \u0415\u0441\u043b\u0438 \u043d\u0435\u0442, \u043e\u0437\u043d\u0430\u043a\u043e\u043c\u044c\u0442\u0435\u0441\u044c \u0441 \u043d\u0430\u0448\u0438\u043c\u0438 \u0441\u0442\u0430\u0442\u044c\u044f\u043c\u0438 \u043e \u0432\u043d\u0438\u043c\u0430\u043d\u0438\u0438 \u0438 \u0442\u0440\u0430\u043d\u0441\u0444\u043e\u0440\u043c\u0435\u0440\u0430\u0445.<\/p>\n<p>\u041d\u0430\u0447\u043d\u0435\u043c \u0441 \u0431\u043b\u043e\u043a\u0430 \u0432\u043d\u0438\u043c\u0430\u043d\u0438\u044f \u043a \u0441\u0435\u0431\u0435.<\/p>\n<h2 id=\"the-self-attention-block\">\u0411\u043b\u043e\u043a \u0432\u043d\u0438\u043c\u0430\u043d\u0438\u044f \u043a \u0441\u0435\u0431\u0435<\/h2>\n<p>\u0412\u043e-\u043f\u0435\u0440\u0432\u044b\u0445, \u043d\u0430\u043c \u043d\u0443\u0436\u043d\u043e \u0438\u043c\u043f\u043e\u0440\u0442\u0438\u0440\u043e\u0432\u0430\u0442\u044c JAX \u0438 Haiku.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> jax<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">numpy <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> jnp<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span 
class=\"token plain\"> haiku <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> hk<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> numpy <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> np<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">from<\/span><span class=\"token plain\"> typing <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> Optional<\/span><\/p><\/pre>\n<p>\u041a \u0441\u0447\u0430\u0441\u0442\u044c\u044e \u0434\u043b\u044f \u043d\u0430\u0441, \u0432 Haiku \u0435\u0441\u0442\u044c \u0432\u0441\u0442\u0440\u043e\u0435\u043d\u043d\u044b\u0439 \u0431\u043b\u043e\u043a <code>MultiHeadAttention<\/code>, \u043a\u043e\u0442\u043e\u0440\u044b\u0439 \u043c\u043e\u0436\u043d\u043e \u0440\u0430\u0441\u0448\u0438\u0440\u0438\u0442\u044c, \u0447\u0442\u043e\u0431\u044b \u043f\u043e\u0441\u0442\u0440\u043e\u0438\u0442\u044c \u0437\u0430\u043c\u0430\u0441\u043a\u0438\u0440\u043e\u0432\u0430\u043d\u043d\u044b\u0439 \u0431\u043b\u043e\u043a \u0432\u043d\u0438\u043c\u0430\u043d\u0438\u044f \u043a \u0441\u0435\u0431\u0435.  \u041d\u0430\u0448 \u0431\u043b\u043e\u043a \u043f\u0440\u0438\u043d\u0438\u043c\u0430\u0435\u0442 \u0437\u0430\u043f\u0440\u043e\u0441, \u043a\u043b\u044e\u0447, \u0437\u043d\u0430\u0447\u0435\u043d\u0438\u0435, \u0430 \u0442\u0430\u043a\u0436\u0435 \u043c\u0430\u0441\u043a\u0443 \u0438 \u0432\u043e\u0437\u0432\u0440\u0430\u0449\u0430\u0435\u0442 \u0440\u0435\u0437\u0443\u043b\u044c\u0442\u0430\u0442 \u0432 \u0432\u0438\u0434\u0435 \u043c\u0430\u0441\u0441\u0438\u0432\u0430 JAX.  \u0412\u044b \u043c\u043e\u0436\u0435\u0442\u0435 \u0432\u0438\u0434\u0435\u0442\u044c, \u0447\u0442\u043e \u043a\u043e\u0434 \u043e\u0447\u0435\u043d\u044c \u043f\u043e\u0445\u043e\u0436 \u043d\u0430 \u0441\u0442\u0430\u043d\u0434\u0430\u0440\u0442\u043d\u044b\u0439 \u043a\u043e\u0434 Pytorch \u0438\u043b\u0438 Tensorflow.  
\u0412\u0441\u0435, \u0447\u0442\u043e \u043c\u044b \u0434\u0435\u043b\u0430\u0435\u043c, \u2014 \u044d\u0442\u043e \u0441\u0442\u0440\u043e\u0438\u043c \u043a\u0430\u0443\u0437\u0430\u043b\u044c\u043d\u0443\u044e \u043c\u0430\u0441\u043a\u0443 \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e \u0444\u0443\u043d\u043a\u0446\u0438\u0438 <code>np.tril()<\/code>, \u043a\u043e\u0442\u043e\u0440\u0430\u044f \u043e\u0431\u043d\u0443\u043b\u044f\u0435\u0442 \u0432\u0441\u0435 \u044d\u043b\u0435\u043c\u0435\u043d\u0442\u044b \u043c\u0430\u0441\u0441\u0438\u0432\u0430 \u0432\u044b\u0448\u0435 k-\u0439 \u0434\u0438\u0430\u0433\u043e\u043d\u0430\u043b\u0438, \u0443\u043c\u043d\u043e\u0436\u0430\u0435\u043c \u0435\u0435 \u043d\u0430 \u043d\u0430\u0448\u0443 \u043c\u0430\u0441\u043a\u0443 \u0438 \u043f\u0435\u0440\u0435\u0434\u0430\u0435\u043c \u0432\u0441\u0435 \u0432 \u043c\u043e\u0434\u0443\u043b\u044c <code>hk.MultiHeadAttention<\/code>.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">SelfAttention<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">MultiHeadAttention<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">\"\"\"Self attention with a causal mask applied.\"\"\"<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" 
style=\"color:rgb(80, 250, 123)\">__call__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            query<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            key<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Optional<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token plain\">jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            value<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Optional<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token plain\">jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">]<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            mask<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Optional<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token plain\">jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        key <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> key <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">if<\/span><span class=\"token plain\"> key <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">is<\/span><span class=\"token plain\"> <\/span><span class=\"token keyword\" 
style=\"color:rgb(189, 147, 249);font-style:italic\">not<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">else<\/span><span class=\"token plain\"> query<\/span><\/p><p><span class=\"token plain\">        value <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> value <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">if<\/span><span class=\"token plain\"> value <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">is<\/span><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">not<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">else<\/span><span class=\"token plain\"> query<\/span><\/p><p><span class=\"token plain\">        seq_len <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> query<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">shape<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        causal_mask <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> np<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">tril<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">np<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">.<\/span><span class=\"token plain\">ones<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">seq_len<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> seq_len<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        mask <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> mask <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> causal_mask <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">if<\/span><span class=\"token plain\"> mask <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">is<\/span><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">not<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">else<\/span><span class=\"token plain\"> causal_mask<\/span><\/p><p><span class=\"token plain\">        <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">super<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span 
class=\"token plain\">__call__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">query<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> key<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> value<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> mask<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<p>\u042d\u0442\u043e\u0442 \u0444\u0440\u0430\u0433\u043c\u0435\u043d\u0442 \u043f\u043e\u0437\u0432\u043e\u043b\u044f\u0435\u0442 \u043c\u043d\u0435 \u043f\u0440\u0435\u0434\u0441\u0442\u0430\u0432\u0438\u0442\u044c \u043f\u0435\u0440\u0432\u044b\u0439 \u043a\u043b\u044e\u0447\u0435\u0432\u043e\u0439 \u043f\u0440\u0438\u043d\u0446\u0438\u043f Haiku.  \u0412\u0441\u0435 \u043c\u043e\u0434\u0443\u043b\u0438 \u0434\u043e\u043b\u0436\u043d\u044b \u0431\u044b\u0442\u044c \u043f\u043e\u0434\u043a\u043b\u0430\u0441\u0441\u0430\u043c\u0438 <code>hk.Module<\/code>.  \u042d\u0442\u043e \u043e\u0437\u043d\u0430\u0447\u0430\u0435\u0442, \u0447\u0442\u043e \u043e\u043d\u0438 \u0434\u043e\u043b\u0436\u043d\u044b \u0440\u0435\u0430\u043b\u0438\u0437\u043e\u0432\u0430\u0442\u044c <code>__init__<\/code> \u0438 <code>__call__<\/code>, \u0430 \u0442\u0430\u043a\u0436\u0435 \u043b\u044e\u0431\u044b\u0435 \u0434\u0440\u0443\u0433\u0438\u0435 \u043c\u0435\u0442\u043e\u0434\u044b.  
\u0412 \u043d\u0435\u043a\u043e\u0442\u043e\u0440\u043e\u043c \u0441\u043c\u044b\u0441\u043b\u0435 \u044d\u0442\u043e \u0442\u0430 \u0436\u0435 \u0430\u0440\u0445\u0438\u0442\u0435\u043a\u0442\u0443\u0440\u0430, \u0447\u0442\u043e \u0438 \u0443 \u043c\u043e\u0434\u0443\u043b\u0435\u0439 Pytorch, \u0433\u0434\u0435 \u043c\u044b \u0440\u0435\u0430\u043b\u0438\u0437\u0443\u0435\u043c <code>__init__<\/code> \u0438 <code>forward<\/code>.<\/p>\n<p>\u0427\u0442\u043e\u0431\u044b \u0441\u0434\u0435\u043b\u0430\u0442\u044c \u044d\u0442\u043e \u043a\u0440\u0438\u0441\u0442\u0430\u043b\u044c\u043d\u043e \u044f\u0441\u043d\u044b\u043c, \u0434\u0430\u0432\u0430\u0439\u0442\u0435 \u0441\u043e\u0437\u0434\u0430\u0434\u0438\u043c \u043f\u0440\u043e\u0441\u0442\u043e\u0439 \u0434\u0432\u0443\u0445\u0441\u043b\u043e\u0439\u043d\u044b\u0439 \u043f\u0435\u0440\u0441\u0435\u043f\u0442\u0440\u043e\u043d (MLP) \u0432 \u0432\u0438\u0434\u0435 <code>hk.Module<\/code>, \u043a\u043e\u0442\u043e\u0440\u044b\u0439 \u0431\u0443\u0434\u0435\u0442 \u0443\u0434\u043e\u0431\u043d\u043e \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u043e\u0432\u0430\u0442\u044c \u0432 Transformer \u043d\u0438\u0436\u0435.<\/p>\n<h2 id=\"the-linear-layer\">\u041b\u0438\u043d\u0435\u0439\u043d\u044b\u0439 \u0441\u043b\u043e\u0439<\/h2>\n<p>\u041f\u0440\u043e\u0441\u0442\u043e\u0439 \u0434\u0432\u0443\u0445\u0441\u043b\u043e\u0439\u043d\u044b\u0439 MLP \u0431\u0443\u0434\u0435\u0442 \u0432\u044b\u0433\u043b\u044f\u0434\u0435\u0442\u044c \u0442\u0430\u043a.  
\u0415\u0449\u0435 \u0440\u0430\u0437, \u0432\u044b \u043c\u043e\u0436\u0435\u0442\u0435 \u0437\u0430\u043c\u0435\u0442\u0438\u0442\u044c, \u043a\u0430\u043a \u0437\u043d\u0430\u043a\u043e\u043c\u043e \u044d\u0442\u043e \u0432\u044b\u0433\u043b\u044f\u0434\u0438\u0442.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">DenseBlock<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Module<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">\"\"\"A 2-layer MLP\"\"\"<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                 init_scale<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">float<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                 widening_factor<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">int<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">4<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                 name<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Optional<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">str<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">super<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">name<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">name<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 
248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_init_scale <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> init_scale<\/span><\/p><p><span class=\"token plain\">        self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_widening_factor <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> widening_factor<\/span><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__call__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        hiddens <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">.<\/span><span class=\"token plain\">shape<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        initializer <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">initializers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">VarianceScaling<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_init_scale<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Linear<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_widening_factor <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> hiddens<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> w_init<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">initializer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 
248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">gelu<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Linear<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">hiddens<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> w_init<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">initializer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<p>\u041d\u0435\u0441\u043a\u043e\u043b\u044c\u043a\u043e \u0432\u0435\u0449\u0435\u0439, \u043d\u0430 \u043a\u043e\u0442\u043e\u0440\u044b\u0435 \u0441\u043b\u0435\u0434\u0443\u0435\u0442 \u043e\u0431\u0440\u0430\u0442\u0438\u0442\u044c \u0432\u043d\u0438\u043c\u0430\u043d\u0438\u0435:<\/p>\n<ul>\n<li>\n<p>Haiku 
\u043f\u0440\u0435\u0434\u043e\u0441\u0442\u0430\u0432\u043b\u044f\u0435\u0442 \u043d\u0430\u043c \u043d\u0430\u0431\u043e\u0440 \u0438\u043d\u0438\u0446\u0438\u0430\u043b\u0438\u0437\u0430\u0442\u043e\u0440\u043e\u0432 \u0432\u0435\u0441\u043e\u0432 \u0432 <code>hk.initializers<\/code>, \u0433\u0434\u0435 \u043c\u044b \u043c\u043e\u0436\u0435\u043c \u043d\u0430\u0439\u0442\u0438 \u043d\u0430\u0438\u0431\u043e\u043b\u0435\u0435 \u0440\u0430\u0441\u043f\u0440\u043e\u0441\u0442\u0440\u0430\u043d\u0435\u043d\u043d\u044b\u0435 \u043f\u043e\u0434\u0445\u043e\u0434\u044b.<\/p>\n<\/li>\n<li>\n<p>\u041e\u043d \u0442\u0430\u043a\u0436\u0435 \u0441\u043e\u0434\u0435\u0440\u0436\u0438\u0442 \u043c\u043d\u043e\u0436\u0435\u0441\u0442\u0432\u043e \u0432\u0441\u0442\u0440\u043e\u0435\u043d\u043d\u044b\u0445 \u043f\u043e\u043f\u0443\u043b\u044f\u0440\u043d\u044b\u0445 \u0441\u043b\u043e\u0435\u0432 \u0438 \u043c\u043e\u0434\u0443\u043b\u0435\u0439, \u0442\u0430\u043a\u0438\u0445 \u043a\u0430\u043a <code>hk.Linear<\/code>.  \u041f\u043e\u043b\u043d\u044b\u0439 \u0441\u043f\u0438\u0441\u043e\u043a \u043c\u043e\u0436\u043d\u043e \u043d\u0430\u0439\u0442\u0438 \u0432 \u043e\u0444\u0438\u0446\u0438\u0430\u043b\u044c\u043d\u043e\u0439 \u0434\u043e\u043a\u0443\u043c\u0435\u043d\u0442\u0430\u0446\u0438\u0438.<\/p>\n<\/li>\n<li>\n<p>\u0424\u0443\u043d\u043a\u0446\u0438\u0438 \u0430\u043a\u0442\u0438\u0432\u0430\u0446\u0438\u0438 \u043d\u0435 \u043f\u0440\u0435\u0434\u043e\u0441\u0442\u0430\u0432\u043b\u044f\u044e\u0442\u0441\u044f, \u043f\u043e\u0441\u043a\u043e\u043b\u044c\u043a\u0443 \u0432 JAX \u0443\u0436\u0435 \u0435\u0441\u0442\u044c \u043f\u043e\u0434\u043f\u0430\u043a\u0435\u0442 \u0441 \u0438\u043c\u0435\u043d\u0435\u043c <code>jax.nn<\/code>, \u0433\u0434\u0435 \u043c\u044b \u043c\u043e\u0436\u0435\u043c \u043d\u0430\u0439\u0442\u0438 \u0444\u0443\u043d\u043a\u0446\u0438\u0438 \u0430\u043a\u0442\u0438\u0432\u0430\u0446\u0438\u0438, \u0442\u0430\u043a\u0438\u0435 \u043a\u0430\u043a <code>relu<\/code> \u0438\u043b\u0438 
<code>softmax<\/code>.<\/p>\n<\/li>\n<\/ul>\n<h2 id=\"the-normalization-layer\">\u0421\u043b\u043e\u0439 \u043d\u043e\u0440\u043c\u0430\u043b\u0438\u0437\u0430\u0446\u0438\u0438<\/h2>\n<p>\u041d\u043e\u0440\u043c\u0430\u043b\u0438\u0437\u0430\u0446\u0438\u044f \u0441\u043b\u043e\u0435\u0432 \u2014 \u0435\u0449\u0435 \u043e\u0434\u0438\u043d \u043d\u0435\u043e\u0442\u044a\u0435\u043c\u043b\u0435\u043c\u044b\u0439 \u0431\u043b\u043e\u043a \u0430\u0440\u0445\u0438\u0442\u0435\u043a\u0442\u0443\u0440\u044b \u0442\u0440\u0430\u043d\u0441\u0444\u043e\u0440\u043c\u0435\u0440\u0430, \u043a\u043e\u0442\u043e\u0440\u044b\u0439 \u043c\u044b \u0442\u0430\u043a\u0436\u0435 \u043c\u043e\u0436\u0435\u043c \u043d\u0430\u0439\u0442\u0438 \u0432 \u043e\u0431\u0449\u0438\u0445 \u043c\u043e\u0434\u0443\u043b\u044f\u0445 \u0432\u043d\u0443\u0442\u0440\u0438 Haiku.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">layer_norm<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> name<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Optional<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">str<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">]<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">\"\"\"Apply a unique LayerNorm to x with default settings.\"\"\"<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">LayerNorm<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">axis<\/span><span class=\"token operator\">=<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                        create_scale<\/span><span class=\"token operator\">=<\/span><span class=\"token boolean\">True<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                        create_offset<\/span><span class=\"token operator\">=<\/span><span class=\"token boolean\">True<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                        name<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">name<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<h2 id=\"the-transformer\">\u0422\u0440\u0430\u043d\u0441\u0444\u043e\u0440\u043c\u0435\u0440<\/h2>\n<p>\u0410 \u0442\u0435\u043f\u0435\u0440\u044c \u0441\u0430\u043c\u043e\u0435 \u0438\u043d\u0442\u0435\u0440\u0435\u0441\u043d\u043e\u0435.  \u041d\u0438\u0436\u0435 \u0432\u044b \u043c\u043e\u0436\u0435\u0442\u0435 \u043d\u0430\u0439\u0442\u0438 \u043e\u0447\u0435\u043d\u044c \u0443\u043f\u0440\u043e\u0449\u0435\u043d\u043d\u044b\u0439 Transformer, \u0432 \u043a\u043e\u0442\u043e\u0440\u043e\u043c \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u0443\u044e\u0442\u0441\u044f \u043d\u0430\u0448\u0438 \u0440\u0430\u043d\u0435\u0435 \u043e\u043f\u0440\u0435\u0434\u0435\u043b\u0435\u043d\u043d\u044b\u0435 \u043c\u043e\u0434\u0443\u043b\u0438.  \u0412\u043d\u0443\u0442\u0440\u0438 <code>__init__<\/code> \u043c\u044b \u043e\u043f\u0440\u0435\u0434\u0435\u043b\u044f\u0435\u043c \u043e\u0441\u043d\u043e\u0432\u043d\u044b\u0435 \u043f\u0435\u0440\u0435\u043c\u0435\u043d\u043d\u044b\u0435, \u0442\u0430\u043a\u0438\u0435 \u043a\u0430\u043a \u043a\u043e\u043b\u0438\u0447\u0435\u0441\u0442\u0432\u043e \u0441\u043b\u043e\u0435\u0432, \u043a\u043e\u043b\u0438\u0447\u0435\u0441\u0442\u0432\u043e \u0433\u043e\u043b\u043e\u0432 \u0432\u043d\u0438\u043c\u0430\u043d\u0438\u044f \u0438 \u0432\u0435\u0440\u043e\u044f\u0442\u043d\u043e\u0441\u0442\u044c dropout.  
\u0412\u043d\u0443\u0442\u0440\u0438 <code>__call__<\/code> \u043c\u044b \u0441\u043e\u0441\u0442\u0430\u0432\u043b\u044f\u0435\u043c \u0441\u043f\u0438\u0441\u043e\u043a \u0431\u043b\u043e\u043a\u043e\u0432 \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e \u0446\u0438\u043a\u043b\u0430 <code>for<\/code>.<\/p>\n<p>\u041a\u0430\u043a \u0432\u0438\u0434\u0438\u0442\u0435, \u043a\u0430\u0436\u0434\u044b\u0439 \u0431\u043b\u043e\u043a \u0432\u043a\u043b\u044e\u0447\u0430\u0435\u0442 \u0432 \u0441\u0435\u0431\u044f:<\/p>\n<ul>\n<li>\u0441\u043b\u043e\u0439 \u043d\u043e\u0440\u043c\u0430\u043b\u0438\u0437\u0430\u0446\u0438\u0438;<\/li>\n<li>\u0431\u043b\u043e\u043a \u0441\u0430\u043c\u043e\u0432\u043d\u0438\u043c\u0430\u043d\u0438\u044f (<code>SelfAttention<\/code>) \u0441 dropout;<\/li>\n<li>\u043e\u0441\u0442\u0430\u0442\u043e\u0447\u043d\u043e\u0435 \u0441\u043e\u0435\u0434\u0438\u043d\u0435\u043d\u0438\u0435;<\/li>\n<li>\u0432\u0442\u043e\u0440\u043e\u0439 \u0441\u043b\u043e\u0439 \u043d\u043e\u0440\u043c\u0430\u043b\u0438\u0437\u0430\u0446\u0438\u0438;<\/li>\n<li>\u043f\u043e\u043b\u043d\u043e\u0441\u0432\u044f\u0437\u043d\u044b\u0439 \u0431\u043b\u043e\u043a (<code>DenseBlock<\/code>) \u0441 dropout;<\/li>\n<li>\u0435\u0449\u0435 \u043e\u0434\u043d\u043e \u043e\u0441\u0442\u0430\u0442\u043e\u0447\u043d\u043e\u0435 \u0441\u043e\u0435\u0434\u0438\u043d\u0435\u043d\u0438\u0435.<\/li>\n<\/ul>\n<p>\u0412 \u043a\u043e\u043d\u0446\u0435 \u043c\u044b \u0442\u0430\u043a\u0436\u0435 \u0434\u043e\u0431\u0430\u0432\u043b\u044f\u0435\u043c \u0444\u0438\u043d\u0430\u043b\u044c\u043d\u044b\u0439 \u0441\u043b\u043e\u0439 \u043d\u043e\u0440\u043c\u0430\u043b\u0438\u0437\u0430\u0446\u0438\u0438.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">Transformer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Module<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">\"\"\"A transformer stack.\"\"\"<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__init__<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                 num_heads<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">int<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                 num_layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">int<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                 dropout_rate<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">float<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                 name<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Optional<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">str<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">super<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">name<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">name<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_num_layers <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> num_layers<\/span><\/p><p><span class=\"token plain\">        self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_num_heads <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> num_heads<\/span><\/p><p><span class=\"token plain\">        self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_dropout_rate <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> dropout_rate<\/span><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__call__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span 
class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                 h<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                 mask<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Optional<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token plain\">jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                 is_training<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">bool<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 
198)\">\"\"\"Connects the transformer.<\/span><\/p><p><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">        Args:<\/span><\/p><p><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">          h: Inputs, [B, T, H].<\/span><\/p><p><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">          mask: Padding mask, [B, T].<\/span><\/p><p><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">          is_training: Whether we're training or not.<\/span><\/p><p><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">        Returns:<\/span><\/p><p><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">          Array of shape [B, T, H].<\/span><\/p><p><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">        \"\"\"<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        init_scale <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">\/<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_num_layers<\/span><\/p><p><span class=\"token plain\">        dropout_rate <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_dropout_rate <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">if<\/span><span class=\"token plain\"> is_training <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">else<\/span><span 
class=\"token plain\"> <\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">if<\/span><span class=\"token plain\"> mask <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">is<\/span><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">not<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            mask <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> mask<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">for<\/span><span class=\"token plain\"> i <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">in<\/span><span class=\"token plain\"> 
<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">range<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_num_layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            h_norm <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> layer_norm<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">h<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> name<\/span><span class=\"token operator\">=<\/span><span class=\"token string-interpolation string\" style=\"color:rgb(255, 121, 198)\">f'h<\/span><span class=\"token string-interpolation interpolation punctuation\" style=\"color:rgb(248, 248, 242)\">{<\/span><span class=\"token string-interpolation interpolation\">i<\/span><span class=\"token string-interpolation interpolation punctuation\" style=\"color:rgb(248, 248, 242)\">}<\/span><span class=\"token string-interpolation string\" style=\"color:rgb(255, 121, 198)\">_ln_1'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            h_attn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> SelfAttention<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                num_heads<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_num_heads<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                key_size<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">64<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                w_init_scale<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">init_scale<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                name<\/span><span class=\"token operator\">=<\/span><span class=\"token string-interpolation string\" style=\"color:rgb(255, 121, 198)\">f'h<\/span><span class=\"token string-interpolation interpolation punctuation\" style=\"color:rgb(248, 248, 242)\">{<\/span><span class=\"token string-interpolation interpolation\">i<\/span><span class=\"token string-interpolation interpolation punctuation\" style=\"color:rgb(248, 248, 242)\">}<\/span><span class=\"token string-interpolation string\" style=\"color:rgb(255, 121, 198)\">_attn'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">h_norm<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> mask<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">mask<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            h_attn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">dropout<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">next_rng_key<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> dropout_rate<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> h_attn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            h <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> h <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> h_attn<\/span><\/p><p><span class=\"token plain\">            h_norm <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> layer_norm<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">h<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> name<\/span><span class=\"token operator\">=<\/span><span class=\"token string-interpolation string\" style=\"color:rgb(255, 121, 198)\">f'h<\/span><span class=\"token string-interpolation interpolation punctuation\" style=\"color:rgb(248, 248, 242)\">{<\/span><span class=\"token string-interpolation interpolation\">i<\/span><span class=\"token string-interpolation interpolation punctuation\" style=\"color:rgb(248, 248, 242)\">}<\/span><span class=\"token string-interpolation string\" style=\"color:rgb(255, 121, 198)\">_ln_2'<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            h_dense <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> DenseBlock<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">init_scale<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> name<\/span><span class=\"token operator\">=<\/span><span class=\"token string-interpolation string\" style=\"color:rgb(255, 121, 198)\">f'h<\/span><span class=\"token string-interpolation interpolation punctuation\" style=\"color:rgb(248, 248, 242)\">{<\/span><span class=\"token string-interpolation interpolation\">i<\/span><span class=\"token string-interpolation interpolation punctuation\" style=\"color:rgb(248, 248, 242)\">}<\/span><span class=\"token string-interpolation string\" style=\"color:rgb(255, 121, 198)\">_mlp'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">h_norm<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            h_dense <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">dropout<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">next_rng_key<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> dropout_rate<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> h_dense<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            h <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> h <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> h_dense<\/span><\/p><p><span class=\"token plain\">        h <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> layer_norm<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">h<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> name<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'ln_f'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> h<\/span><\/p><\/pre>\n<p>\u042f \u0434\u0443\u043c\u0430\u044e, \u0432\u044b \u0443\u0436\u0435 \u043f\u043e\u043d\u044f\u043b\u0438, \u0447\u0442\u043e \u043f\u043e\u0441\u0442\u0440\u043e\u0438\u0442\u044c \u043d\u0435\u0439\u0440\u043e\u043d\u043d\u0443\u044e \u0441\u0435\u0442\u044c \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e JAX \u043e\u0447\u0435\u043d\u044c \u043f\u0440\u043e\u0441\u0442\u043e.<\/p>\n<h2 id=\"the-embeddings-layer\">\u0421\u043b\u043e\u0439 \u0432\u0441\u0442\u0440\u0430\u0438\u0432\u0430\u043d\u0438\u044f<\/h2>\n<p>\u0414\u043b\u044f \u0437\u0430\u0432\u0435\u0440\u0448\u0435\u043d\u0438\u044f 
\u0434\u0430\u0432\u0430\u0439\u0442\u0435 \u0442\u0430\u043a\u0436\u0435 \u0432\u043a\u043b\u044e\u0447\u0438\u043c \u0441\u043b\u043e\u0439 \u0432\u0441\u0442\u0440\u0430\u0438\u0432\u0430\u043d\u0438\u044f.  \u041f\u043e\u043b\u0435\u0437\u043d\u043e \u0437\u043d\u0430\u0442\u044c, \u0447\u0442\u043e Haiku \u0442\u0430\u043a\u0436\u0435 \u043f\u0440\u0435\u0434\u043e\u0441\u0442\u0430\u0432\u043b\u044f\u0435\u0442 \u0441\u043b\u043e\u0439 \u0432\u0441\u0442\u0440\u0430\u0438\u0432\u0430\u043d\u0438\u044f (<code>hk.Embed<\/code>), \u043a\u043e\u0442\u043e\u0440\u044b\u0439 \u043f\u0440\u0435\u043e\u0431\u0440\u0430\u0437\u0443\u0435\u0442 \u0442\u043e\u043a\u0435\u043d\u044b \u043d\u0430\u0448\u0435\u0433\u043e \u0432\u0445\u043e\u0434\u043d\u043e\u0433\u043e \u043f\u0440\u0435\u0434\u043b\u043e\u0436\u0435\u043d\u0438\u044f \u0432 \u0432\u0435\u043a\u0442\u043e\u0440\u044b \u0432\u0441\u0442\u0440\u0430\u0438\u0432\u0430\u043d\u0438\u044f.  \u0417\u0430\u0442\u0435\u043c \u0432\u0441\u0442\u0440\u0430\u0438\u0432\u0430\u043d\u0438\u044f \u0442\u043e\u043a\u0435\u043d\u043e\u0432 \u0441\u043a\u043b\u0430\u0434\u044b\u0432\u0430\u044e\u0442\u0441\u044f \u0441 \u043f\u043e\u0437\u0438\u0446\u0438\u043e\u043d\u043d\u044b\u043c\u0438 \u0432\u0441\u0442\u0440\u0430\u0438\u0432\u0430\u043d\u0438\u044f\u043c\u0438, \u0447\u0442\u043e \u0438 \u0434\u0430\u0435\u0442 \u043e\u043a\u043e\u043d\u0447\u0430\u0442\u0435\u043b\u044c\u043d\u044b\u0435 \u0432\u0445\u043e\u0434\u043d\u044b\u0435 \u0434\u0430\u043d\u043d\u044b\u0435.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">embeddings<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Mapping<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 
249)\">str<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">int<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    tokens <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'obs'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    input_mask <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">greater<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">tokens<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    seq_length <\/span><span class=\"token 
operator\">=<\/span><span class=\"token plain\"> tokens<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">shape<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    embed_init <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">initializers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">TruncatedNormal<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">stddev<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">0.02<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    token_embedding_map <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Embed<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> d_model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> w_init<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">embed_init<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    token_embs <\/span><span 
class=\"token operator\">=<\/span><span class=\"token plain\"> token_embedding_map<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">tokens<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    positional_embeddings <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">get_parameter<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'pos_embs'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token plain\">seq_length<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> d_model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> init<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">embed_init<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    input_embeddings <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> token_embs <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> positional_embeddings<\/span><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span 
class=\"token plain\"> input_embeddings<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> input_mask<\/span><\/p><\/pre>\n<p><code>hk.get_parameter(param_name, ...)<\/code>  \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u0443\u0435\u0442\u0441\u044f \u0434\u043b\u044f \u0434\u043e\u0441\u0442\u0443\u043f\u0430 \u043a \u043e\u0431\u0443\u0447\u0430\u0435\u043c\u044b\u043c \u043f\u0430\u0440\u0430\u043c\u0435\u0442\u0440\u0430\u043c \u043c\u043e\u0434\u0443\u043b\u044f.  \u041d\u043e \u0432\u044b \u043c\u043e\u0436\u0435\u0442\u0435 \u0441\u043f\u0440\u043e\u0441\u0438\u0442\u044c, \u043f\u043e\u0447\u0435\u043c\u0443 \u0431\u044b \u043f\u0440\u043e\u0441\u0442\u043e \u043d\u0435 \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u043e\u0432\u0430\u0442\u044c \u0441\u0432\u043e\u0439\u0441\u0442\u0432\u0430 \u043e\u0431\u044a\u0435\u043a\u0442\u0430, \u043a\u0430\u043a \u043c\u044b \u044d\u0442\u043e \u0434\u0435\u043b\u0430\u0435\u043c \u0432 PyTorch.  \u0417\u0434\u0435\u0441\u044c \u0432\u0441\u0442\u0443\u043f\u0430\u0435\u0442 \u0432 \u0434\u0435\u0439\u0441\u0442\u0432\u0438\u0435 \u0432\u0442\u043e\u0440\u043e\u0439 \u043a\u043b\u044e\u0447\u0435\u0432\u043e\u0439 \u043f\u0440\u0438\u043d\u0446\u0438\u043f Haiku. <strong>\u041c\u044b \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u0443\u0435\u043c \u044d\u0442\u043e\u0442 API, \u0447\u0442\u043e\u0431\u044b \u043f\u0440\u0435\u043e\u0431\u0440\u0430\u0437\u043e\u0432\u0430\u0442\u044c \u043a\u043e\u0434 \u0432 \u0447\u0438\u0441\u0442\u0443\u044e \u0444\u0443\u043d\u043a\u0446\u0438\u044e, \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u0443\u044f<\/strong> <code>hk.transform<\/code>.  
\u042d\u0442\u043e \u043d\u0435 \u043e\u0447\u0435\u043d\u044c \u043f\u0440\u043e\u0441\u0442\u043e \u043f\u043e\u043d\u044f\u0442\u044c, \u043d\u043e \u044f \u043f\u043e\u0441\u0442\u0430\u0440\u0430\u044e\u0441\u044c \u0441\u0434\u0435\u043b\u0430\u0442\u044c \u044d\u0442\u043e \u043a\u0430\u043a \u043c\u043e\u0436\u043d\u043e \u0431\u043e\u043b\u0435\u0435 \u044f\u0441\u043d\u044b\u043c.<\/p>\n<h2 id=\"why-pure-functions\">\u041f\u043e\u0447\u0435\u043c\u0443 \u0447\u0438\u0441\u0442\u044b\u0435 \u0444\u0443\u043d\u043a\u0446\u0438\u0438?<\/h2>\n<p>\u0421\u0438\u043b\u0430 JAX \u0437\u0430\u043a\u043b\u044e\u0447\u0430\u0435\u0442\u0441\u044f \u0432 \u043f\u0440\u0435\u043e\u0431\u0440\u0430\u0437\u043e\u0432\u0430\u043d\u0438\u044f\u0445 \u0444\u0443\u043d\u043a\u0446\u0438\u0439: \u0432\u043e\u0437\u043c\u043e\u0436\u043d\u043e\u0441\u0442\u044c \u0432\u0435\u043a\u0442\u043e\u0440\u0438\u0437\u043e\u0432\u0430\u0442\u044c \u0444\u0443\u043d\u043a\u0446\u0438\u044e \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e <code>vmap<\/code>, \u0430\u0432\u0442\u043e\u043c\u0430\u0442\u0438\u0447\u0435\u0441\u043a\u043e\u0435 \u0440\u0430\u0441\u043f\u0430\u0440\u0430\u043b\u043b\u0435\u043b\u0438\u0432\u0430\u043d\u0438\u0435 \u0441 <code>pmap<\/code> \u0438 JIT-\u043a\u043e\u043c\u043f\u0438\u043b\u044f\u0446\u0438\u044f \u0441 <code>jit<\/code>.  
\u041f\u0440\u0435\u0434\u043e\u0441\u0442\u0435\u0440\u0435\u0436\u0435\u043d\u0438\u0435 \u0437\u0434\u0435\u0441\u044c \u0437\u0430\u043a\u043b\u044e\u0447\u0430\u0435\u0442\u0441\u044f \u0432 \u0442\u043e\u043c, \u0447\u0442\u043e \u0434\u043b\u044f \u043f\u0440\u0435\u043e\u0431\u0440\u0430\u0437\u043e\u0432\u0430\u043d\u0438\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u0438 \u043e\u043d\u0430 \u0434\u043e\u043b\u0436\u043d\u0430 \u0431\u044b\u0442\u044c \u0447\u0438\u0441\u0442\u043e\u0439.<\/p>\n<p><strong>\u0427\u0438\u0441\u0442\u0430\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u044f<\/strong> \u2014 \u044d\u0442\u043e \u0444\u0443\u043d\u043a\u0446\u0438\u044f, \u043e\u0431\u043b\u0430\u0434\u0430\u044e\u0449\u0430\u044f \u0441\u043b\u0435\u0434\u0443\u044e\u0449\u0438\u043c\u0438 \u0441\u0432\u043e\u0439\u0441\u0442\u0432\u0430\u043c\u0438:<\/p>\n<ul>\n<li>\n<p>\u0412\u043e\u0437\u0432\u0440\u0430\u0449\u0430\u0435\u043c\u044b\u0435 \u0437\u043d\u0430\u0447\u0435\u043d\u0438\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u0438 \u0438\u0434\u0435\u043d\u0442\u0438\u0447\u043d\u044b \u0434\u043b\u044f \u0438\u0434\u0435\u043d\u0442\u0438\u0447\u043d\u044b\u0445 \u0430\u0440\u0433\u0443\u043c\u0435\u043d\u0442\u043e\u0432 (\u043d\u0435 \u043e\u043f\u0440\u0435\u0434\u0435\u043b\u044f\u044e\u0442\u0441\u044f \u043b\u043e\u043a\u0430\u043b\u044c\u043d\u044b\u043c\u0438 \u0441\u0442\u0430\u0442\u0438\u0447\u0435\u0441\u043a\u0438\u043c\u0438 \u043f\u0435\u0440\u0435\u043c\u0435\u043d\u043d\u044b\u043c\u0438, \u043d\u0435\u043b\u043e\u043a\u0430\u043b\u044c\u043d\u044b\u043c\u0438 \u043f\u0435\u0440\u0435\u043c\u0435\u043d\u043d\u044b\u043c\u0438, \u0438\u0437\u043c\u0435\u043d\u044f\u0435\u043c\u044b\u043c\u0438 \u0441\u0441\u044b\u043b\u043e\u0447\u043d\u044b\u043c\u0438 \u0430\u0440\u0433\u0443\u043c\u0435\u043d\u0442\u0430\u043c\u0438 \u0438\u043b\u0438 \u0432\u0445\u043e\u0434\u043d\u044b\u043c\u0438 
\u043f\u043e\u0442\u043e\u043a\u0430\u043c\u0438).<\/p>\n<\/li>\n<li>\n<p>\u041f\u0440\u0438\u043b\u043e\u0436\u0435\u043d\u0438\u0435 \u0444\u0443\u043d\u043a\u0446\u0438\u0438 \u043d\u0435 \u0438\u043c\u0435\u0435\u0442 \u043f\u043e\u0431\u043e\u0447\u043d\u044b\u0445 \u044d\u0444\u0444\u0435\u043a\u0442\u043e\u0432 (\u0431\u0435\u0437 \u0438\u0437\u043c\u0435\u043d\u0435\u043d\u0438\u044f \u043b\u043e\u043a\u0430\u043b\u044c\u043d\u044b\u0445 \u0441\u0442\u0430\u0442\u0438\u0447\u0435\u0441\u043a\u0438\u0445 \u043f\u0435\u0440\u0435\u043c\u0435\u043d\u043d\u044b\u0445, \u043d\u0435\u043b\u043e\u043a\u0430\u043b\u044c\u043d\u044b\u0445 \u043f\u0435\u0440\u0435\u043c\u0435\u043d\u043d\u044b\u0445, \u0438\u0437\u043c\u0435\u043d\u044f\u0435\u043c\u044b\u0445 \u0441\u0441\u044b\u043b\u043e\u0447\u043d\u044b\u0445 \u0430\u0440\u0433\u0443\u043c\u0435\u043d\u0442\u043e\u0432 \u0438\u043b\u0438 \u043f\u043e\u0442\u043e\u043a\u043e\u0432 \u0432\u0432\u043e\u0434\u0430\/\u0432\u044b\u0432\u043e\u0434\u0430).<\/p>\n<\/li>\n<\/ul>\n<p><span class=\"gatsby-resp-image-wrapper\" style=\"position:relative;display:block;margin-left:auto;margin-right:auto;max-width:716px\"><\/p>\n<p>    <span class=\"gatsby-resp-image-background-image\" style=\"padding-bottom:36.333333333333336%;position:relative;bottom:0;left:0;background-image:url('data:image\/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAHCAYAAAAIy204AAAACXBIWXMAAAsTAAALEwEAmpwYAAAA+klEQVQoz3WRi26DMAxF+f\/vG491UA1oiwZ0vMIzxHd2BBVIK5J1Hd1jOw4OAYshWldjNOspAOh+GHTdNDYkZ956R37dVDyHE4M3X9t1aNsOXGxDzk3bvsPBveBwV5MXBZI0JVZ6ZBndHxnxjYgbsI1TdErRT56T8EVZWk1vNyqfT+s7i9YmThJ8uC5F1ys8P8BXGMIPAgzjCAbt+fNygQxmHmHEnOcjYhXP9Tx8xzF4GI4r0\/H64zRZQFZVfQ+leuh15XzANM\/\/riy+XVl250IS3XLbvarrU\/G8LPitqhdz5EWlxt5wM14Nd1NAfkv7IyRkwD5sZ44q8Qe2ER9h0OVQvgAAAABJRU5ErkJggg==');background-size:cover;display:block\"\/><br \/>\n  <img decoding=\"async\" class=\"gatsby-resp-image-image\" alt=\"\u0447\u0438\u0441\u0442\u0430\u044f 
\u0444\u0443\u043d\u043a\u0446\u0438\u044f\" title=\"\u0447\u0438\u0441\u0442\u0430\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u044f\" src=\"https:\/\/theaisummer.com\/static\/3c20387ac7cc2e6f7086dfecce73fd28\/6bbf7\/pure-function.png\" srcset=\"\/static\/3c20387ac7cc2e6f7086dfecce73fd28\/5a46d\/pure-function.png 300w,\/static\/3c20387ac7cc2e6f7086dfecce73fd28\/0a47e\/pure-function.png 600w,\/static\/3c20387ac7cc2e6f7086dfecce73fd28\/6bbf7\/pure-function.png 716w\" sizes=\"(max-width: 716px) 100vw, 716px\" style=\"width:100%;height:100%;margin:0;vertical-align:middle;position:absolute;top:0;left:0\" loading=\"lazy\"\/><\/p>\n<p>    <\/span><br \/>\n<em>\u0418\u0441\u0442\u043e\u0447\u043d\u0438\u043a: \u0447\u0438\u0441\u0442\u044b\u0435 \u0444\u0443\u043d\u043a\u0446\u0438\u0438 Scala \u043e\u0442 O&#8217;Reilly.<\/em><\/p>\n<p>\u041d\u0430 \u043f\u0440\u0430\u043a\u0442\u0438\u043a\u0435 \u044d\u0442\u043e \u043e\u0437\u043d\u0430\u0447\u0430\u0435\u0442, \u0447\u0442\u043e \u0434\u043b\u044f \u0447\u0438\u0441\u0442\u043e\u0439 \u0444\u0443\u043d\u043a\u0446\u0438\u0438 \u0432\u0441\u0435\u0433\u0434\u0430 \u0432\u044b\u043f\u043e\u043b\u043d\u044f\u0435\u0442\u0441\u044f \u0441\u043b\u0435\u0434\u0443\u044e\u0449\u0435\u0435:<\/p>\n<ul>\n<li>\n<p><strong>\u043e\u043d\u0430 \u0432\u043e\u0437\u0432\u0440\u0430\u0449\u0430\u0435\u0442 \u0442\u043e\u0442 \u0436\u0435 \u0440\u0435\u0437\u0443\u043b\u044c\u0442\u0430\u0442, \u0435\u0441\u043b\u0438 \u0432\u044b\u0437\u044b\u0432\u0430\u0435\u0442\u0441\u044f \u0441 \u0442\u0435\u043c\u0438 \u0436\u0435 \u0432\u0445\u043e\u0434\u043d\u044b\u043c\u0438 \u0434\u0430\u043d\u043d\u044b\u043c\u0438<\/strong><\/p>\n<\/li>\n<li>\n<p><strong>\u0432\u0441\u0435 \u0432\u0445\u043e\u0434\u043d\u044b\u0435 \u0434\u0430\u043d\u043d\u044b\u0435 \u043f\u0435\u0440\u0435\u0434\u0430\u044e\u0442\u0441\u044f \u0447\u0435\u0440\u0435\u0437 \u0430\u0440\u0433\u0443\u043c\u0435\u043d\u0442\u044b \u0444\u0443\u043d\u043a\u0446\u0438\u0438, \u0432\u0441\u0435 \u0440\u0435\u0437\u0443\u043b\u044c\u0442\u0430\u0442\u044b 
\u0432\u043e\u0437\u0432\u0440\u0430\u0449\u0430\u044e\u0442\u0441\u044f \u0447\u0435\u0440\u0435\u0437 \u0432\u043e\u0437\u0432\u0440\u0430\u0449\u0430\u0435\u043c\u044b\u0435 \u0437\u043d\u0430\u0447\u0435\u043d\u0438\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u0438<\/strong><\/p>\n<\/li>\n<\/ul>\n<p>Haiku \u043f\u0440\u0435\u0434\u043e\u0441\u0442\u0430\u0432\u043b\u044f\u0435\u0442 \u0444\u0443\u043d\u043a\u0446\u0438\u043e\u043d\u0430\u043b\u044c\u043d\u0443\u044e \u0442\u0440\u0430\u043d\u0441\u0444\u043e\u0440\u043c\u0430\u0446\u0438\u044e, \u043d\u0430\u0437\u044b\u0432\u0430\u0435\u043c\u0443\u044e <code>hk.transform<\/code>, \u043a\u043e\u0442\u043e\u0440\u0430\u044f \u043f\u0440\u0435\u0432\u0440\u0430\u0449\u0430\u0435\u0442 \u0444\u0443\u043d\u043a\u0446\u0438\u0438 \u0441 \u043e\u0431\u044a\u0435\u043a\u0442\u043d\u043e-\u043e\u0440\u0438\u0435\u043d\u0442\u0438\u0440\u043e\u0432\u0430\u043d\u043d\u044b\u043c\u0438, \u0444\u0443\u043d\u043a\u0446\u0438\u043e\u043d\u0430\u043b\u044c\u043d\u043e \u00ab\u043d\u0435\u0447\u0438\u0441\u0442\u044b\u043c\u0438\u00bb \u043c\u043e\u0434\u0443\u043b\u044f\u043c\u0438 \u0432 \u0447\u0438\u0441\u0442\u044b\u0435 \u0444\u0443\u043d\u043a\u0446\u0438\u0438, \u043a\u043e\u0442\u043e\u0440\u044b\u0435 \u043c\u043e\u0436\u043d\u043e \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u043e\u0432\u0430\u0442\u044c \u0441 JAX.  
\u0427\u0442\u043e\u0431\u044b \u0443\u0432\u0438\u0434\u0435\u0442\u044c \u044d\u0442\u043e \u043d\u0430 \u043f\u0440\u0430\u043a\u0442\u0438\u043a\u0435, \u0434\u0430\u0432\u0430\u0439\u0442\u0435 \u043f\u0440\u043e\u0434\u043e\u043b\u0436\u0438\u043c \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u0435 \u043d\u0430\u0448\u0435\u0439 \u043c\u043e\u0434\u0435\u043b\u0438 Transformer.<\/p>\n<h2 id=\"the-forward-pass\">\u041f\u0440\u043e\u0445\u043e\u0434 \u0432\u043f\u0435\u0440\u0435\u0434<\/h2>\n<p>\u0422\u0438\u043f\u0438\u0447\u043d\u044b\u0439 \u043f\u0440\u044f\u043c\u043e\u0439 \u043f\u0440\u043e\u0445\u043e\u0434 \u0432\u043a\u043b\u044e\u0447\u0430\u0435\u0442 \u0432 \u0441\u0435\u0431\u044f:<\/p>\n<ol>\n<li>\n<p>\u041f\u043e\u043b\u0443\u0447\u0435\u043d\u0438\u0435 \u0432\u0445\u043e\u0434\u043d\u044b\u0445 \u0434\u0430\u043d\u043d\u044b\u0445 \u0438 \u0432\u044b\u0447\u0438\u0441\u043b\u0435\u043d\u0438\u0435 \u0432\u0445\u043e\u0434\u043d\u043e\u0433\u043e \u0432\u0441\u0442\u0440\u0430\u0438\u0432\u0430\u043d\u0438\u044f<\/p>\n<\/li>\n<li>\n<p>\u041f\u0440\u043e\u0433\u043e\u043d \u0447\u0435\u0440\u0435\u0437 \u0431\u043b\u043e\u043a\u0438 \u0422\u0440\u0430\u043d\u0441\u0444\u043e\u0440\u043c\u0435\u0440\u0430<\/p>\n<\/li>\n<li>\n<p>\u0412\u043e\u0437\u0432\u0440\u0430\u0442 \u0432\u044b\u0445\u043e\u0434\u043d\u044b\u0445 \u0434\u0430\u043d\u043d\u044b\u0445<\/p>\n<\/li>\n<\/ol>\n<p>\u0412\u044b\u0448\u0435\u0443\u043f\u043e\u043c\u044f\u043d\u0443\u0442\u044b\u0435 \u0448\u0430\u0433\u0438 \u043c\u043e\u0436\u043d\u043e \u043b\u0435\u0433\u043a\u043e \u0441\u043e\u0441\u0442\u0430\u0432\u0438\u0442\u044c \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e JAX \u0441\u043b\u0435\u0434\u0443\u044e\u0449\u0438\u043c \u043e\u0431\u0440\u0430\u0437\u043e\u043c:<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" 
style=\"color:rgb(80, 250, 123)\">build_forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">int<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> d_model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">int<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> num_heads<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">int<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                     num_layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">int<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> dropout_rate<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">float<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token triple-quoted-string string\" 
style=\"color:rgb(255, 121, 198)\">\"\"\"Create the model's forward pass.\"\"\"<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Mapping<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">str<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                   is_training<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">bool<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">True<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">\"\"\"Forward pass.\"\"\"<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        input_embeddings<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> input_mask <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> embeddings<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        transformer <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> Transformer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            num_heads<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">num_heads<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> num_layers<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">num_layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> dropout_rate<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">dropout_rate<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        output_embeddings <\/span><span class=\"token operator\">=<\/span><span class=\"token 
plain\"> transformer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">input_embeddings<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> input_mask<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> is_training<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Linear<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">output_embeddings<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> forward_fn<\/span><\/p><\/pre>\n<p>\u0425\u043e\u0442\u044f \u043a\u043e\u0434 \u043f\u0440\u043e\u0441\u0442, \u0435\u0433\u043e \u0441\u0442\u0440\u0443\u043a\u0442\u0443\u0440\u0430 \u043c\u043e\u0436\u0435\u0442 \u043f\u043e\u043a\u0430\u0437\u0430\u0442\u044c\u0441\u044f \u043d\u0435\u043c\u043d\u043e\u0433\u043e \u0441\u0442\u0440\u0430\u043d\u043d\u043e\u0439.  
\u0424\u0430\u043a\u0442\u0438\u0447\u0435\u0441\u043a\u0438\u0439 \u043f\u0440\u044f\u043c\u043e\u0439 \u043f\u0440\u043e\u0445\u043e\u0434 \u0432\u044b\u043f\u043e\u043b\u043d\u044f\u0435\u0442\u0441\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u0435\u0439 <code>forward_fn<\/code>.  \u0422\u0435\u043c \u043d\u0435 \u043c\u0435\u043d\u0435\u0435, \u043c\u044b \u043e\u0431\u043e\u0440\u0430\u0447\u0438\u0432\u0430\u0435\u043c \u0435\u0435 \u0432 \u0444\u0443\u043d\u043a\u0446\u0438\u044e <code>build_forward_fn<\/code>, \u043a\u043e\u0442\u043e\u0440\u0430\u044f \u0432\u043e\u0437\u0432\u0440\u0430\u0449\u0430\u0435\u0442 <code>forward_fn<\/code>.  \u041a\u0430\u043a\u043e\u0433\u043e \u0447\u0435\u0440\u0442\u0430?<\/p>\n<p>\u0412 \u0434\u0430\u043b\u044c\u043d\u0435\u0439\u0448\u0435\u043c \u043d\u0430\u043c \u043d\u0443\u0436\u043d\u043e \u0431\u0443\u0434\u0435\u0442 \u043f\u0440\u0435\u043e\u0431\u0440\u0430\u0437\u043e\u0432\u0430\u0442\u044c \u0444\u0443\u043d\u043a\u0446\u0438\u044e <code>forward_fn<\/code> \u0432 \u0447\u0438\u0441\u0442\u0443\u044e \u0444\u0443\u043d\u043a\u0446\u0438\u044e \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e <code>hk.transform<\/code>, \u0447\u0442\u043e\u0431\u044b \u043c\u044b \u043c\u043e\u0433\u043b\u0438 \u0432\u043e\u0441\u043f\u043e\u043b\u044c\u0437\u043e\u0432\u0430\u0442\u044c\u0441\u044f \u043f\u0440\u0435\u0438\u043c\u0443\u0449\u0435\u0441\u0442\u0432\u0430\u043c\u0438 \u0430\u0432\u0442\u043e\u043c\u0430\u0442\u0438\u0447\u0435\u0441\u043a\u043e\u0433\u043e \u0434\u0438\u0444\u0444\u0435\u0440\u0435\u043d\u0446\u0438\u0440\u043e\u0432\u0430\u043d\u0438\u044f, \u0440\u0430\u0441\u043f\u0430\u0440\u0430\u043b\u043b\u0435\u043b\u0438\u0432\u0430\u043d\u0438\u044f \u0438 \u0442. 
\u0434.<\/p>\n<p>\u042d\u0442\u043e \u0431\u0443\u0434\u0435\u0442 \u0434\u043e\u0441\u0442\u0438\u0433\u043d\u0443\u0442\u043e \u0437\u0430 \u0441\u0447\u0435\u0442:<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token plain\">forward_fn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> build_forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> d_model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> num_heads<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                                  num_layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> dropout_rate<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">forward_fn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">transform<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<p>\u0412\u043e\u0442 \u043f\u043e\u0447\u0435\u043c\u0443 \u0432\u043c\u0435\u0441\u0442\u043e \u043f\u0440\u043e\u0441\u0442\u043e\u0433\u043e \u043e\u043f\u0440\u0435\u0434\u0435\u043b\u0435\u043d\u0438\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u0438 \u043c\u044b 
\u0443\u043f\u0430\u043a\u043e\u0432\u044b\u0432\u0430\u0435\u043c \u0438 \u0432\u043e\u0437\u0432\u0440\u0430\u0449\u0430\u0435\u043c \u0441\u0430\u043c\u0443 \u0444\u0443\u043d\u043a\u0446\u0438\u044e \u0438\u043b\u0438 \u0432\u044b\u0437\u044b\u0432\u0430\u0435\u043c\u044b\u0439 \u043e\u0431\u044a\u0435\u043a\u0442, \u0435\u0441\u043b\u0438 \u0431\u044b\u0442\u044c \u0431\u043e\u043b\u0435\u0435 \u0442\u043e\u0447\u043d\u044b\u043c.  \u0417\u0430\u0442\u0435\u043c \u044d\u0442\u043e\u0442 \u0432\u044b\u0437\u044b\u0432\u0430\u0435\u043c\u044b\u0439 \u043e\u0431\u044a\u0435\u043a\u0442 \u043c\u043e\u0436\u0435\u0442 \u0431\u044b\u0442\u044c \u043f\u0435\u0440\u0435\u0434\u0430\u043d \u0432 <code>hk.transform<\/code> \u0438 \u0441\u0442\u0430\u0442\u044c \u0447\u0438\u0441\u0442\u043e\u0439 \u0444\u0443\u043d\u043a\u0446\u0438\u0435\u0439.  \u0415\u0441\u043b\u0438 \u044d\u0442\u043e \u043f\u043e\u043d\u044f\u0442\u043d\u043e, \u0434\u0430\u0432\u0430\u0439\u0442\u0435 \u043f\u0435\u0440\u0435\u0439\u0434\u0435\u043c \u043a \u0444\u0443\u043d\u043a\u0446\u0438\u0438 \u043f\u043e\u0442\u0435\u0440\u044c.<\/p>\n<h2 id=\"the-loss-function\">\u0424\u0443\u043d\u043a\u0446\u0438\u044f \u043f\u043e\u0442\u0435\u0440\u044c<\/h2>\n<p>\u0424\u0443\u043d\u043a\u0446\u0438\u044f \u043f\u043e\u0442\u0435\u0440\u044c \u2014 \u044d\u0442\u043e \u043d\u0430\u0448\u0430 \u0445\u043e\u0440\u043e\u0448\u043e \u0438\u0437\u0432\u0435\u0441\u0442\u043d\u0430\u044f \u043a\u0440\u043e\u0441\u0441-\u044d\u043d\u0442\u0440\u043e\u043f\u0438\u0439\u043d\u0430\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u044f \u0441 \u0442\u043e\u0439 \u0440\u0430\u0437\u043d\u0438\u0446\u0435\u0439, \u0447\u0442\u043e \u043c\u044b \u0442\u0430\u043a\u0436\u0435 \u0443\u0447\u0438\u0442\u044b\u0432\u0430\u0435\u043c \u043c\u0430\u0441\u043a\u0443.  
\u0418 \u0441\u043d\u043e\u0432\u0430 JAX \u043f\u0440\u0435\u0434\u043e\u0441\u0442\u0430\u0432\u043b\u044f\u0435\u0442 <code>one_hot<\/code> \u0438 <code>log_softmax<\/code> \u0444\u0443\u043d\u043a\u0446\u0438\u043e\u043d\u0430\u043b\u044c\u043d\u044b\u0435 \u0432\u043e\u0437\u043c\u043e\u0436\u043d\u043e\u0441\u0442\u0438.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">lm_loss_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">               vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">int<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">               params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">               rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">               data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Mapping<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">str<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">,<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">               is_training<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">bool<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token boolean\">True<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token operator\">&gt;<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">\"\"\"Compute the loss on data wrt params.\"\"\"<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    logits <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> data<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> is_training<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    targets <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">one_hot<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'target'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">assert<\/span><span class=\"token plain\"> logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">shape <\/span><span class=\"token operator\">==<\/span><span class=\"token plain\"> targets<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">shape<\/span><\/p><p><span class=\"token plain\">    mask <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">greater<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">(<\/span><span class=\"token plain\">data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'obs'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    loss <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token plain\">jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">sum<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">targets <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">log_softmax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> axis<\/span><span class=\"token operator\">=<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    loss <\/span><span class=\"token operator\">=<\/span><span 
class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">sum<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">loss <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> mask<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">\/<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">sum<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">mask<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> loss<\/span><\/p><\/pre>\n<p>\u0415\u0441\u043b\u0438 \u0432\u044b \u0432\u0441\u0435 \u0435\u0449\u0435 \u0441\u043e \u043c\u043d\u043e\u0439, \u0441\u0434\u0435\u043b\u0430\u0439\u0442\u0435 \u0433\u043b\u043e\u0442\u043e\u043a \u043a\u043e\u0444\u0435, \u043f\u043e\u0442\u043e\u043c\u0443 \u0447\u0442\u043e \u0441 \u044d\u0442\u043e\u0433\u043e \u043c\u043e\u043c\u0435\u043d\u0442\u0430 \u0432\u0441\u0435 \u0441\u0442\u0430\u043d\u0435\u0442 \u0441\u0435\u0440\u044c\u0435\u0437\u043d\u043e.  
\u041f\u0440\u0438\u0448\u043b\u043e \u0432\u0440\u0435\u043c\u044f \u043f\u043e\u0441\u0442\u0440\u043e\u0438\u0442\u044c \u043d\u0430\u0448 \u0442\u0440\u0435\u043d\u0438\u0440\u043e\u0432\u043e\u0447\u043d\u044b\u0439 \u0446\u0438\u043a\u043b.<\/p>\n<h2 id=\"the-training-loop\">\u0422\u0440\u0435\u043d\u0438\u0440\u043e\u0432\u043e\u0447\u043d\u044b\u0439 \u0446\u0438\u043a\u043b<\/h2>\n<p>\u041f\u043e\u0441\u043a\u043e\u043b\u044c\u043a\u0443 \u043d\u0438 JAX, \u043d\u0438 Haiku \u043d\u0435 \u0438\u043c\u0435\u044e\u0442 \u0432\u0441\u0442\u0440\u043e\u0435\u043d\u043d\u044b\u0445 \u0444\u0443\u043d\u043a\u0446\u0438\u0439 \u043e\u043f\u0442\u0438\u043c\u0438\u0437\u0430\u0446\u0438\u0438, \u043c\u044b \u0432\u043e\u0441\u043f\u043e\u043b\u044c\u0437\u0443\u0435\u043c\u0441\u044f \u0434\u0440\u0443\u0433\u0438\u043c \u0444\u0440\u0435\u0439\u043c\u0432\u043e\u0440\u043a\u043e\u043c, \u043d\u0430\u0437\u044b\u0432\u0430\u0435\u043c\u044b\u043c Optax.  \u041a\u0430\u043a \u0443\u043f\u043e\u043c\u0438\u043d\u0430\u043b\u043e\u0441\u044c \u0432 \u043d\u0430\u0447\u0430\u043b\u0435, Optax \u2014 \u044d\u0442\u043e \u043e\u0441\u043d\u043e\u0432\u043d\u043e\u0439 \u043f\u0430\u043a\u0435\u0442 \u0434\u043b\u044f \u043e\u0431\u0440\u0430\u0431\u043e\u0442\u043a\u0438 \u0433\u0440\u0430\u0434\u0438\u0435\u043d\u0442\u043e\u0432.<\/p>\n<p>\u0412\u043e-\u043f\u0435\u0440\u0432\u044b\u0445, \u0432\u043e\u0442 \u043d\u0435\u043a\u043e\u0442\u043e\u0440\u044b\u0435 \u0432\u0435\u0449\u0438, \u043a\u043e\u0442\u043e\u0440\u044b\u0435 \u0432\u0430\u043c \u043d\u0443\u0436\u043d\u043e \u0437\u043d\u0430\u0442\u044c \u043e\u0431 Optax:<\/p>\n<p>\u041a\u043b\u044e\u0447\u0435\u0432\u044b\u043c \u043f\u0440\u0435\u043e\u0431\u0440\u0430\u0437\u043e\u0432\u0430\u043d\u0438\u0435\u043c Optax \u044f\u0432\u043b\u044f\u0435\u0442\u0441\u044f <code>GradientTransformation<\/code>.  
\u041f\u0440\u0435\u043e\u0431\u0440\u0430\u0437\u043e\u0432\u0430\u043d\u0438\u0435 \u043e\u043f\u0440\u0435\u0434\u0435\u043b\u044f\u0435\u0442\u0441\u044f \u0434\u0432\u0443\u043c\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u044f\u043c\u0438: <code>init<\/code> \u0438 <code>update<\/code>. <code>init<\/code> \u0438\u043d\u0438\u0446\u0438\u0430\u043b\u0438\u0437\u0438\u0440\u0443\u0435\u0442 \u0441\u043e\u0441\u0442\u043e\u044f\u043d\u0438\u0435, \u0430 <code>update<\/code> \u043f\u0440\u0435\u043e\u0431\u0440\u0430\u0437\u0443\u0435\u0442 \u0433\u0440\u0430\u0434\u0438\u0435\u043d\u0442\u044b \u0441 \u0443\u0447\u0435\u0442\u043e\u043c \u0441\u043e\u0441\u0442\u043e\u044f\u043d\u0438\u044f \u0438 \u0442\u0435\u043a\u0443\u0449\u0435\u0433\u043e \u0437\u043d\u0430\u0447\u0435\u043d\u0438\u044f \u043f\u0430\u0440\u0430\u043c\u0435\u0442\u0440\u043e\u0432.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token plain\">state <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> init<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">grads<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> state <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> update<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">grads<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> params<\/span><span class=\"token 
operator\">=<\/span><span class=\"token boolean\">None<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<p>\u0415\u0449\u0435 \u043e\u0434\u043d\u0430 \u0432\u0435\u0449\u044c, \u043a\u043e\u0442\u043e\u0440\u0443\u044e \u043d\u0443\u0436\u043d\u043e \u0437\u043d\u0430\u0442\u044c, \u043f\u0440\u0435\u0436\u0434\u0435 \u0447\u0435\u043c \u043c\u044b \u0443\u0432\u0438\u0434\u0438\u043c \u043a\u043e\u0434, \u2014 \u044d\u0442\u043e \u0432\u0441\u0442\u0440\u043e\u0435\u043d\u043d\u0430\u044f \u0432 Python \u0444\u0443\u043d\u043a\u0446\u0438\u044f <code>functools.partial<\/code>. \u041c\u043e\u0434\u0443\u043b\u044c <code>functools<\/code> \u043f\u0440\u0435\u0434\u043d\u0430\u0437\u043d\u0430\u0447\u0435\u043d \u0434\u043b\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u0439 \u0432\u044b\u0441\u0448\u0435\u0433\u043e \u043f\u043e\u0440\u044f\u0434\u043a\u0430 \u0438 \u043e\u043f\u0435\u0440\u0430\u0446\u0438\u0439 \u043d\u0430\u0434 \u0432\u044b\u0437\u044b\u0432\u0430\u0435\u043c\u044b\u043c\u0438 \u043e\u0431\u044a\u0435\u043a\u0442\u0430\u043c\u0438.<\/p>\n<blockquote>\n<p>\u0424\u0443\u043d\u043a\u0446\u0438\u044f \u043d\u0430\u0437\u044b\u0432\u0430\u0435\u0442\u0441\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u0435\u0439 \u0432\u044b\u0441\u0448\u0435\u0433\u043e \u043f\u043e\u0440\u044f\u0434\u043a\u0430, \u0435\u0441\u043b\u0438 \u043e\u043d\u0430 \u043f\u0440\u0438\u043d\u0438\u043c\u0430\u0435\u0442 \u0434\u0440\u0443\u0433\u0438\u0435 \u0444\u0443\u043d\u043a\u0446\u0438\u0438 \u0432 \u043a\u0430\u0447\u0435\u0441\u0442\u0432\u0435 \u043f\u0430\u0440\u0430\u043c\u0435\u0442\u0440\u043e\u0432 \u0438\u043b\u0438 \u0432\u043e\u0437\u0432\u0440\u0430\u0449\u0430\u0435\u0442 \u0444\u0443\u043d\u043a\u0446\u0438\u044e \u0432 \u043a\u0430\u0447\u0435\u0441\u0442\u0432\u0435 \u0440\u0435\u0437\u0443\u043b\u044c\u0442\u0430\u0442\u0430.<\/p>\n<\/blockquote>\n<p> <code>partial<\/code>, \u043a\u043e\u0442\u043e\u0440\u044b\u0439 \u0442\u0430\u043a\u0436\u0435 \u043c\u043e\u0436\u043d\u043e 
\u0438\u0441\u043f\u043e\u043b\u044c\u0437\u043e\u0432\u0430\u0442\u044c \u0432 \u043a\u0430\u0447\u0435\u0441\u0442\u0432\u0435 \u0434\u0435\u043a\u043e\u0440\u0430\u0442\u043e\u0440\u0430, \u0432\u043e\u0437\u0432\u0440\u0430\u0449\u0430\u0435\u0442 \u043d\u043e\u0432\u0443\u044e \u0444\u0443\u043d\u043a\u0446\u0438\u044e, \u043e\u0441\u043d\u043e\u0432\u0430\u043d\u043d\u0443\u044e \u043d\u0430 \u0438\u0441\u0445\u043e\u0434\u043d\u043e\u0439, \u0432 \u043a\u043e\u0442\u043e\u0440\u043e\u0439 \u0447\u0430\u0441\u0442\u044c \u0430\u0440\u0433\u0443\u043c\u0435\u043d\u0442\u043e\u0432 \u0437\u0430\u0444\u0438\u043a\u0441\u0438\u0440\u043e\u0432\u0430\u043d\u0430.  \u0415\u0441\u043b\u0438, \u043d\u0430\u043f\u0440\u0438\u043c\u0435\u0440, f \u0443\u043c\u043d\u043e\u0436\u0430\u0435\u0442 \u0434\u0432\u0430 \u0437\u043d\u0430\u0447\u0435\u043d\u0438\u044f x, y, <code>partial<\/code> \u0441\u043e\u0437\u0434\u0430\u0441\u0442 \u043d\u043e\u0432\u0443\u044e \u0444\u0443\u043d\u043a\u0446\u0438\u044e, \u0433\u0434\u0435 x \u0431\u0443\u0434\u0435\u0442 \u0444\u0438\u043a\u0441\u0438\u0440\u043e\u0432\u0430\u043d\u043d\u044b\u043c \u0438 \u0440\u0430\u0432\u043d\u044b\u043c 2.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">from<\/span><span class=\"token plain\"> functools <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> partial<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">f<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span 
class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\">y<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> x <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> y<\/span><\/p><p><span class=\"token plain\">g <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> partial<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">f<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">print<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">g<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">4<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><\/pre>\n<p>\u041f\u043e\u0441\u043b\u0435 \u044d\u0442\u043e\u0433\u043e \u043a\u043e\u0440\u043e\u0442\u043a\u043e\u0433\u043e \u043e\u0431\u0445\u043e\u0434\u0430 \u0434\u0430\u0432\u0430\u0439\u0442\u0435 \u043f\u0440\u043e\u0434\u043e\u043b\u0436\u0438\u043c.  
\u0427\u0442\u043e\u0431\u044b \u0440\u0430\u0437\u0433\u0440\u0443\u0437\u0438\u0442\u044c \u043d\u0430\u0448 <code>main<\/code> \u043c\u044b \u0438\u0437\u0432\u043b\u0435\u0447\u0435\u043c \u043e\u0431\u043d\u043e\u0432\u043b\u0435\u043d\u0438\u0435 \u0433\u0440\u0430\u0434\u0438\u0435\u043d\u0442\u043e\u0432 \u0432 \u043e\u0442\u0434\u0435\u043b\u044c\u043d\u044b\u0439 \u043a\u043b\u0430\u0441\u0441.<\/p>\n<p>\u043f\u0440\u0435\u0436\u0434\u0435 \u0432\u0441\u0435\u0433\u043e <code>GradientUpdater<\/code> \u043f\u0440\u0438\u043d\u0438\u043c\u0430\u0435\u0442 \u043c\u043e\u0434\u0435\u043b\u044c, \u0444\u0443\u043d\u043a\u0446\u0438\u044e \u043f\u043e\u0442\u0435\u0440\u044c \u0438 \u043e\u043f\u0442\u0438\u043c\u0438\u0437\u0430\u0442\u043e\u0440.<\/p>\n<ol>\n<li>\u041c\u043e\u0434\u0435\u043b\u044c \u0431\u0443\u0434\u0435\u0442 \u0447\u0438\u0441\u0442\u043e\u0439 <code>forward_fn<\/code> \u0444\u0443\u043d\u043a\u0446\u0438\u044f, \u043f\u0440\u0435\u043e\u0431\u0440\u0430\u0437\u043e\u0432\u0430\u043d\u043d\u0430\u044f <code>hk.transform<\/code><\/li>\n<\/ol>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token plain\">forward_fn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> build_forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> d_model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> num_heads<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                                  num_layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> 
dropout_rate<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">forward_fn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">transform<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<ol start=\"2\">\n<li>\u0424\u0443\u043d\u043a\u0446\u0438\u044f \u043f\u043e\u0442\u0435\u0440\u044c \u0431\u0443\u0434\u0435\u0442 \u0440\u0435\u0437\u0443\u043b\u044c\u0442\u0430\u0442\u043e\u043c <code>functools.partial<\/code> \u0441 \u0444\u0438\u043a\u0441\u0438\u0440\u043e\u0432\u0430\u043d\u043d\u044b\u043c\u0438 \u0430\u0440\u0433\u0443\u043c\u0435\u043d\u0442\u0430\u043c\u0438 <code>forward_fn.apply<\/code> \u0438 <code>vocab_size<\/code><\/li>\n<\/ol>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token plain\">loss_fn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> functools<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">partial<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">lm_loss_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">apply<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<ol 
start=\"3\">\n<li>\u041e\u043f\u0442\u0438\u043c\u0438\u0437\u0430\u0442\u043e\u0440 \u043f\u0440\u0435\u0434\u0441\u0442\u0430\u0432\u043b\u044f\u0435\u0442 \u0441\u043e\u0431\u043e\u0439 \u043d\u0430\u0431\u043e\u0440 \u043e\u043f\u0442\u0438\u043c\u0438\u0437\u0430\u0446\u0438\u043e\u043d\u043d\u044b\u0445 \u043f\u0440\u0435\u043e\u0431\u0440\u0430\u0437\u043e\u0432\u0430\u043d\u0438\u0439, \u043a\u043e\u0442\u043e\u0440\u044b\u0435 \u0431\u0443\u0434\u0443\u0442 \u0432\u044b\u043f\u043e\u043b\u043d\u044f\u0442\u044c\u0441\u044f \u043f\u043e\u0441\u043b\u0435\u0434\u043e\u0432\u0430\u0442\u0435\u043b\u044c\u043d\u043e (\u043e\u043f\u0435\u0440\u0430\u0446\u0438\u0438 \u043c\u043e\u0436\u043d\u043e \u043a\u043e\u043c\u0431\u0438\u043d\u0438\u0440\u043e\u0432\u0430\u0442\u044c \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e <code>optax.chain<\/code> )<\/li>\n<\/ol>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token plain\">optimizer <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> optax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">chain<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        optax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">clip_by_global_norm<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">grad_clip_value<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        optax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token 
plain\">adam<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">learning_rate<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> b1<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">0.9<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> b2<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">0.99<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<p>\u0421\u0440\u0435\u0434\u0441\u0442\u0432\u043e \u043e\u0431\u043d\u043e\u0432\u043b\u0435\u043d\u0438\u044f Gradient \u0431\u0443\u0434\u0435\u0442 \u0438\u043d\u0438\u0446\u0438\u0430\u043b\u0438\u0437\u0438\u0440\u043e\u0432\u0430\u043d\u043e \u0441\u043b\u0435\u0434\u0443\u044e\u0449\u0438\u043c \u043e\u0431\u0440\u0430\u0437\u043e\u043c:<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token plain\">updater <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> GradientUpdater<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">init<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> loss_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<p>\u0438 \u0431\u0443\u0434\u0435\u0442 \u0432\u044b\u0433\u043b\u044f\u0434\u0435\u0442\u044c 
\u0442\u0430\u043a:<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">GradientUpdater<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">\"\"\"A stateless abstraction around an init_fn\/update_fn pair.<\/span><\/p><p><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">    This extracts some common boilerplate from the training loop.<\/span><\/p><p><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">    \"\"\"<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> net_init<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> loss_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                 optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> optax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">GradientTransformation<\/span><span 
class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_net_init <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> net_init<\/span><\/p><p><span class=\"token plain\">        self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_loss_fn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> loss_fn<\/span><\/p><p><span class=\"token plain\">        self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_opt <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> optimizer<\/span><\/p><p><span class=\"token plain\">    <\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">@functools<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">partial<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">jit<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> static_argnums<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 
249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">init<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> master_rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">\"\"\"Initializes state of the updater.\"\"\"<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        out_rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> init_rng <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">random<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">split<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">master_rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        params <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_net_init<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span 
class=\"token plain\">init_rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        opt_state <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_opt<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">init<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        out <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">dict<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            step<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">np<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">array<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            rng<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">out_rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token 
plain\"\/><\/p><p><span class=\"token plain\">            opt_state<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">opt_state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            params<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> out<\/span><\/p><p><span class=\"token plain\">    <\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">@functools<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">partial<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">jit<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> static_argnums<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" 
style=\"color:rgb(80, 250, 123)\">update<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Mapping<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">str<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> Any<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> Mapping<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">str<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ndarray<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token triple-quoted-string string\" style=\"color:rgb(255, 121, 198)\">\"\"\"Updates the state using some data and returns metrics.\"\"\"<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        rng<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> new_rng <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">random<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">split<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'rng'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        params <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'params'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        loss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> g <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">value_and_grad<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_loss_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        updates<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> opt_state <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_opt<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">update<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">g<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'opt_state'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        params <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> optax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">apply_updates<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">,<\/span><span class=\"token plain\"> updates<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        new_state <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">{<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            <\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'step'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'step'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            <\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'rng'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> new_rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            <\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'opt_state'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> opt_state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            <\/span><span class=\"token string\" 
style=\"color:rgb(255, 121, 198)\">'params'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">}<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        metrics <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">{<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            <\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'step'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'step'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">            <\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'loss'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> loss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">}<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> 
new_state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> metrics<\/span><\/p><\/pre>\n<p>\u0412\u043d\u0443\u0442\u0440\u0438 \u043c\u0435\u0442\u043e\u0434\u0430 <code>init<\/code> \u043c\u044b \u0438\u043d\u0438\u0446\u0438\u0430\u043b\u0438\u0437\u0438\u0440\u0443\u0435\u043c \u043d\u0430\u0448 \u043e\u043f\u0442\u0438\u043c\u0438\u0437\u0430\u0442\u043e\u0440 \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e <code>self._opt.init(params)<\/code> \u0438 \u043e\u0431\u044a\u044f\u0432\u043b\u044f\u0435\u043c \u0441\u043e\u0441\u0442\u043e\u044f\u043d\u0438\u0435 \u043e\u043f\u0442\u0438\u043c\u0438\u0437\u0430\u0446\u0438\u0438.  \u0421\u043e\u0441\u0442\u043e\u044f\u043d\u0438\u0435 \u0431\u0443\u0434\u0435\u0442 \u0441\u043b\u043e\u0432\u0430\u0440\u0435\u043c \u0441 \u043a\u043b\u044e\u0447\u0430\u043c\u0438 <code>step<\/code>, <code>rng<\/code>, <code>opt_state<\/code> \u0438 <code>params<\/code>.<\/p>\n<p>\u041c\u0435\u0442\u043e\u0434 <code>update<\/code> \u0431\u0443\u0434\u0435\u0442 \u043e\u0431\u043d\u043e\u0432\u043b\u044f\u0442\u044c \u043a\u0430\u043a \u0441\u043e\u0441\u0442\u043e\u044f\u043d\u0438\u0435 \u043e\u043f\u0442\u0438\u043c\u0438\u0437\u0430\u0442\u043e\u0440\u0430, \u0442\u0430\u043a \u0438 \u043e\u0431\u0443\u0447\u0430\u0435\u043c\u044b\u0435 \u043f\u0430\u0440\u0430\u043c\u0435\u0442\u0440\u044b.  
\u0412 \u043a\u043e\u043d\u0446\u0435 \u043a\u043e\u043d\u0446\u043e\u0432, \u043e\u043d \u0432\u0435\u0440\u043d\u0435\u0442 \u043d\u043e\u0432\u043e\u0435 \u0441\u043e\u0441\u0442\u043e\u044f\u043d\u0438\u0435.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token plain\">updates<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> opt_state <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">_opt<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">update<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">g<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'opt_state'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> params <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> optax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">apply_updates<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> updates<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<p>\u0415\u0449\u0435 
\u0434\u0432\u0435 \u0432\u0435\u0449\u0438, \u043d\u0430 \u043a\u043e\u0442\u043e\u0440\u044b\u0435 \u0441\u043b\u0435\u0434\u0443\u0435\u0442 \u043e\u0431\u0440\u0430\u0442\u0438\u0442\u044c \u0432\u043d\u0438\u043c\u0430\u043d\u0438\u0435:<\/p>\n<ul>\n<li>\n<p><code>jax.value_and_grad()<\/code> \u2014 \u044d\u0442\u043e \u0441\u043f\u0435\u0446\u0438\u0430\u043b\u044c\u043d\u0430\u044f \u0444\u0443\u043d\u043a\u0446\u0438\u044f, \u043a\u043e\u0442\u043e\u0440\u0430\u044f \u043f\u0440\u0438\u043d\u0438\u043c\u0430\u0435\u0442 \u0434\u0438\u0444\u0444\u0435\u0440\u0435\u043d\u0446\u0438\u0440\u0443\u0435\u043c\u0443\u044e \u0444\u0443\u043d\u043a\u0446\u0438\u044e \u0438 \u0432\u043e\u0437\u0432\u0440\u0430\u0449\u0430\u0435\u0442 \u043d\u043e\u0432\u0443\u044e \u0444\u0443\u043d\u043a\u0446\u0438\u044e, \u0432\u044b\u0447\u0438\u0441\u043b\u044f\u044e\u0449\u0443\u044e \u043e\u0434\u043d\u043e\u0432\u0440\u0435\u043c\u0435\u043d\u043d\u043e \u0435\u0435 \u0437\u043d\u0430\u0447\u0435\u043d\u0438\u0435 \u0438 \u0433\u0440\u0430\u0434\u0438\u0435\u043d\u0442\u044b.<\/p>\n<\/li>\n<li>\n<p>\u041e\u0431\u0430 \u043c\u0435\u0442\u043e\u0434\u0430 <code>init<\/code> \u0438 <code>update<\/code> \u043f\u043e\u043c\u0435\u0447\u0435\u043d\u044b \u0434\u0435\u043a\u043e\u0440\u0430\u0442\u043e\u0440\u043e\u043c <code>@functools.partial(jax.jit, static_argnums=0)<\/code>, \u043a\u043e\u0442\u043e\u0440\u044b\u0439 \u0432\u044b\u0437\u043e\u0432\u0435\u0442 JIT-\u043a\u043e\u043c\u043f\u0438\u043b\u044f\u0442\u043e\u0440 \u0438 \u0441\u043a\u043e\u043c\u043f\u0438\u043b\u0438\u0440\u0443\u0435\u0442 \u0438\u0445 \u0432 XLA \u0432\u043e \u0432\u0440\u0435\u043c\u044f \u0432\u044b\u043f\u043e\u043b\u043d\u0435\u043d\u0438\u044f (<code>static_argnums=0<\/code> \u0434\u0435\u043b\u0430\u0435\u0442 <code>self<\/code> \u0441\u0442\u0430\u0442\u0438\u0447\u0435\u0441\u043a\u0438\u043c \u0430\u0440\u0433\u0443\u043c\u0435\u043d\u0442\u043e\u043c).  
\u041e\u0431\u0440\u0430\u0442\u0438\u0442\u0435 \u0432\u043d\u0438\u043c\u0430\u043d\u0438\u0435, \u0447\u0442\u043e \u0435\u0441\u043b\u0438 \u043c\u044b \u043d\u0435 \u043f\u0440\u0435\u043e\u0431\u0440\u0430\u0437\u043e\u0432\u0430\u043b\u0438<code>forward_fn<\/code> \u0432 \u0447\u0438\u0441\u0442\u0443\u044e \u0444\u0443\u043d\u043a\u0446\u0438\u044e, \u044d\u0442\u043e \u0431\u044b\u043b\u043e \u0431\u044b \u043d\u0435\u0432\u043e\u0437\u043c\u043e\u0436\u043d\u043e.<\/p>\n<\/li>\n<\/ul>\n<p>\u041d\u0430\u043a\u043e\u043d\u0435\u0446, \u043c\u044b \u0433\u043e\u0442\u043e\u0432\u044b \u043f\u043e\u0441\u0442\u0440\u043e\u0438\u0442\u044c \u0432\u0435\u0441\u044c \u0446\u0438\u043a\u043b \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f, \u043a\u043e\u0442\u043e\u0440\u044b\u0439 \u043e\u0431\u044a\u0435\u0434\u0438\u043d\u044f\u0435\u0442 \u0432\u0441\u0435 \u0438\u0434\u0435\u0438 \u0438 \u043a\u043e\u0434, \u0443\u043f\u043e\u043c\u044f\u043d\u0443\u0442\u044b\u0435 \u0434\u043e \u0441\u0438\u0445 \u043f\u043e\u0440.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">main<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    train_dataset<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> vocab_size <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> load<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token 
plain\">batch_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                                     sequence_length<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    forward_fn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> build_forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> d_model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> num_heads<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">                                  num_layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> dropout_rate<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    forward_fn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> hk<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">transform<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    loss_fn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> functools<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">.<\/span><span class=\"token plain\">partial<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">lm_loss_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">apply<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> vocab_size<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    optimizer <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> optax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">chain<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        optax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">clip_by_global_norm<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">grad_clip_value<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        optax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">adam<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">learning_rate<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> b1<\/span><span 
class=\"token operator\">=<\/span><span class=\"token number\">0.9<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> b2<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">0.99<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    updater <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> GradientUpdater<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">forward_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">init<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> loss_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    logging<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">info<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'Initializing parameters...'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    rng <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">random<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span 
class=\"token plain\">PRNGKey<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">428<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    data <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">next<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">train_dataset<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    state <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> updater<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">init<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    logging<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">info<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'Starting train loop...'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    prev_time <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> time<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token 
plain\">time<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">    <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">for<\/span><span class=\"token plain\"> step <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">in<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">range<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">MAX_STEPS<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        data <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">next<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">train_dataset<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">        state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> metrics <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> updater<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">update<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">state<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> data<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<p>\u041e\u0431\u0440\u0430\u0442\u0438\u0442\u0435 \u0432\u043d\u0438\u043c\u0430\u043d\u0438\u0435, \u043a\u0430\u043a \u043c\u044b \u0438\u0441\u043f\u043e\u043b\u044c\u0437\u0443\u0435\u043c <code>GradientUpdater<\/code>.  \u042d\u0442\u043e \u0432\u0441\u0435\u0433\u043e \u0434\u0432\u0435 \u0441\u0442\u0440\u043e\u0447\u043a\u0438 \u043a\u043e\u0434\u0430:<\/p>\n<ul>\n<li>\n<p><code>state = updater.init(rng, data)<\/code><\/p>\n<\/li>\n<li>\n<p><code>state, metrics = updater.update(state, data)<\/code><\/p>\n<\/li>\n<\/ul>\n<p>\u0412\u043e\u0442 \u0438 \u0432\u0441\u0435.  \u042f \u043d\u0430\u0434\u0435\u044e\u0441\u044c, \u0447\u0442\u043e \u0442\u0435\u043f\u0435\u0440\u044c \u0443 \u0432\u0430\u0441 \u0435\u0441\u0442\u044c \u0431\u043e\u043b\u0435\u0435 \u0447\u0435\u0442\u043a\u043e\u0435 \u043f\u0440\u0435\u0434\u0441\u0442\u0430\u0432\u043b\u0435\u043d\u0438\u0435 \u043e JAX \u0438 \u0435\u0433\u043e \u0432\u043e\u0437\u043c\u043e\u0436\u043d\u043e\u0441\u0442\u044f\u0445.<\/p>\n<h2 id=\"acknowledgments\">\u0411\u043b\u0430\u0433\u043e\u0434\u0430\u0440\u043d\u043e\u0441\u0442\u0438<\/h2>\n<p>\u041f\u0440\u0435\u0434\u0441\u0442\u0430\u0432\u043b\u0435\u043d\u043d\u044b\u0439 \u043a\u043e\u0434 \u0441\u0438\u043b\u044c\u043d\u043e \u0432\u0434\u043e\u0445\u043d\u043e\u0432\u043b\u0435\u043d \u043e\u0444\u0438\u0446\u0438\u0430\u043b\u044c\u043d\u044b\u043c\u0438 \u043f\u0440\u0438\u043c\u0435\u0440\u0430\u043c\u0438 \u0444\u0440\u0435\u0439\u043c\u0432\u043e\u0440\u043a\u0430 Haiku.  \u041e\u043d \u0431\u044b\u043b \u0438\u0437\u043c\u0435\u043d\u0435\u043d, \u0447\u0442\u043e\u0431\u044b \u0441\u043e\u043e\u0442\u0432\u0435\u0442\u0441\u0442\u0432\u043e\u0432\u0430\u0442\u044c \u043f\u043e\u0442\u0440\u0435\u0431\u043d\u043e\u0441\u0442\u044f\u043c \u044d\u0442\u043e\u0439 \u0441\u0442\u0430\u0442\u044c\u0438.  
\u041f\u043e\u043b\u043d\u044b\u0439 \u0441\u043f\u0438\u0441\u043e\u043a \u043f\u0440\u0438\u043c\u0435\u0440\u043e\u0432 \u043c\u043e\u0436\u043d\u043e \u043d\u0430\u0439\u0442\u0438 \u0432 \u043e\u0444\u0438\u0446\u0438\u0430\u043b\u044c\u043d\u043e\u043c \u0440\u0435\u043f\u043e\u0437\u0438\u0442\u043e\u0440\u0438\u0438.<\/p>\n<h2 id=\"conclusion\">\u0417\u0430\u043a\u043b\u044e\u0447\u0435\u043d\u0438\u0435<\/h2>\n<p>\u0412 \u044d\u0442\u043e\u0439 \u0441\u0442\u0430\u0442\u044c\u0435 \u043c\u044b \u0443\u0432\u0438\u0434\u0435\u043b\u0438, \u043a\u0430\u043a \u043c\u043e\u0436\u043d\u043e \u0440\u0430\u0437\u0440\u0430\u0431\u043e\u0442\u0430\u0442\u044c \u0438 \u043e\u0431\u0443\u0447\u0438\u0442\u044c \u0432\u0430\u043d\u0438\u043b\u044c\u043d\u043e\u0433\u043e \u0442\u0440\u0430\u043d\u0441\u0444\u043e\u0440\u043c\u0435\u0440\u0430 \u0432 JAX \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e Haiku.  \u0425\u043e\u0442\u044f \u043a\u043e\u0434 \u043d\u0435 \u0442\u0430\u043a \u0443\u0436 \u0441\u043b\u043e\u0436\u0435\u043d \u0434\u043b\u044f \u043f\u043e\u043d\u0438\u043c\u0430\u043d\u0438\u044f, \u0435\u043c\u0443 \u0432\u0441\u0435 \u0436\u0435 \u043d\u0435 \u0445\u0432\u0430\u0442\u0430\u0435\u0442 \u0447\u0438\u0442\u0430\u0431\u0435\u043b\u044c\u043d\u043e\u0441\u0442\u0438 PyTorch \u0438\u043b\u0438 TensorFlow.  
\u042f \u043d\u0430\u0441\u0442\u043e\u044f\u0442\u0435\u043b\u044c\u043d\u043e \u0440\u0435\u043a\u043e\u043c\u0435\u043d\u0434\u0443\u044e \u043f\u043e\u0438\u0433\u0440\u0430\u0442\u044c \u0441 \u043d\u0438\u043c, \u043e\u0442\u043a\u0440\u044b\u0442\u044c \u0434\u043b\u044f \u0441\u0435\u0431\u044f \u0441\u0438\u043b\u044c\u043d\u044b\u0435 \u0438 \u0441\u043b\u0430\u0431\u044b\u0435 \u0441\u0442\u043e\u0440\u043e\u043d\u044b JAX \u0438 \u043f\u043e\u0441\u043c\u043e\u0442\u0440\u0435\u0442\u044c, \u043f\u043e\u0434\u043e\u0439\u0434\u0435\u0442 \u043b\u0438 \u043e\u043d \u0434\u043b\u044f \u0432\u0430\u0448\u0435\u0433\u043e \u0441\u043b\u0435\u0434\u0443\u044e\u0449\u0435\u0433\u043e \u043f\u0440\u043e\u0435\u043a\u0442\u0430.  \u041f\u043e \u043c\u043e\u0435\u043c\u0443 \u043e\u043f\u044b\u0442\u0443, JAX \u043e\u0447\u0435\u043d\u044c \u0441\u0438\u043b\u0435\u043d \u0434\u043b\u044f \u0438\u0441\u0441\u043b\u0435\u0434\u043e\u0432\u0430\u0442\u0435\u043b\u044c\u0441\u043a\u0438\u0445 \u043f\u0440\u0438\u043b\u043e\u0436\u0435\u043d\u0438\u0439, \u0442\u0440\u0435\u0431\u0443\u044e\u0449\u0438\u0445 \u0432\u044b\u0441\u043e\u043a\u043e\u0439 \u043f\u0440\u043e\u0438\u0437\u0432\u043e\u0434\u0438\u0442\u0435\u043b\u044c\u043d\u043e\u0441\u0442\u0438, \u043d\u043e \u0441\u043e\u0432\u0435\u0440\u0448\u0435\u043d\u043d\u043e \u043d\u0435\u0437\u0440\u0435\u043b \u0434\u043b\u044f \u0440\u0435\u0430\u043b\u044c\u043d\u044b\u0445 \u043f\u0440\u043e\u0435\u043a\u0442\u043e\u0432.  
\u0414\u0430\u0439\u0442\u0435 \u043d\u0430\u043c \u0437\u043d\u0430\u0442\u044c, \u0447\u0442\u043e \u0432\u044b \u0434\u0443\u043c\u0430\u0435\u0442\u0435, \u0432 \u043d\u0430\u0448\u0435\u043c \u0434\u0438\u0441\u043a\u043e\u0440\u0434-\u043a\u0430\u043d\u0430\u043b\u0435.<\/p>\n<div class=\"dl-prod-book-inline-banner\">\n<div class=\"dl-prod-book-inline-banner__image gatsby-image-wrapper\" style=\"position:relative;overflow:hidden\"><img decoding=\"async\" aria-hidden=\"true\" src=\"data:image\/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAZCAYAAAAxFw7TAAAACXBIWXMAAAsTAAALEwEAmpwYAAAFzklEQVQ4y1WU2VOb1xnGzyfVbt1O6tST4AVjglc2s1sskgxC8GlBIAmQkMRugtlBbAJjwAKMDYbYYOyMSTJt46YzSV1PO2k7uWimnfauF532or3oTC\/7L3R6+ev7SaGeXjxzzrf9znOe9\/2OUtm1qAs3UVl2tOybmLJlvGBDy7Kina9JyZRSNWaRKVPGzCq0c5WYDJ21oE6Xp5VRhlI5TtR79WmdrkG9UykPqmQuOlONelc+OCUvn5KX3y5B\/UD0drG8Z9wrRZ0sFri8a4DPiNRlF2aRAQyPrxNff87U2gFtw6tU+keIjCWJTawzvvKY2Y0Dhha2GVnaQY\/FCQ3d5dbsfTFQgSlLzIh7pa75UBdFWW5qe1fpTuzROb\/PVX2Y0uA0ocltAiMbtE9sEZ3ZJTi6SfPQBs7uRapDs\/iH1sStwM7ZUedtAsz1425uIhDw4vU6OVdSj8ftoMxWj8\/bwPVqB5lFdgotdhx1VtwNVoJuK9fK7ehOGzeqqrHW1lNgcXDich1KK2zn4wUPP7zj4XfbLg4THkZ7mni55ObpjJsv111M9Xn40xOdtqAH3ethc9TFyzs6S+\/r7Md1\/rDVwMsFJyfzpB7HSjqo8wXJqGynJ9aC2+9nsLuZtrbmFLjR18xkvwe\/30NvxEul04tN99Ag4NsxF5FWF\/W6Tle7zndzddlySSeqMIa6HkXlRVD5HaIwKleUJ7oYxJwXSs8vNcv9AOpKi8xFl+X6klGDJimqV649AizrRSvrQZV2YyrtRCuJoRVHMZfIAvltnKrpZ2f2EYVTEr61H5PAtII2TPmtaHkBNKmBltuCKVfgRoFVxQCq\/JaoH1XWh7GAAVeFEU44pjmdPGQucR9HZBLL4h6ZvoS4bJXn4rgglFpU5QXTzgWuVOUQyjKIuvE+R3BzWT9V\/juci39IxsAqnzbO0FQywFnnAGUrh1wJJMWlRHO94w24oD0NVzVjqOpRVNUIpqrhFNTqWmR5ag9\/KMFaU4KvJp9i8cxhzumg0NZLT\/QhVyvku6IIWpFEcz3yBq7scZRtSvKZQKsZF\/AYbk+cj1pn+bN+l78PPqY3usj3yjo4W9xLpkDHJp7SPvYzcSTAUsm\/uEskxS2S4qraedTNWdEMJvu0wOP01Y\/yr+ga\/+zZpcEhkRRIPvkBsotDXL0Wxj\/5gqEP\/yKuelI7Mkv+ZqO4BlzV30E5FlB1CUx1afi9lkW+GNgkw3FbYEGOF4c5WRqmqCjM3PkAv50\/YPfVP8gqHybTcpu3JH\/zUWGVvopqWEY575KCO5fIk9yOX5B\/M8fD8Ypuvl8RJV968UVBN\/+Ww4CvP+WvyedU3JjkRPUI\/1dY5d1AuaXHXPdIwRu\/gVtkVeOczGkkp7yTr9
sm4bND+M0hf5u7z4B1ireMzK2i6rFUUZVRVNX8AOXbRDUJ2LMucGkJgWvG3LnIaWlW50U3fzx4xn9+\/3O+jK4SqZrnlMSkORJoRv5G9rKAqpkQoH8L1fKQFPgI3iJjYBtTaI9vh5+SaRujtGaQYNMyDlkkuy2JCm+jWh+ljegr\/K8WKZgBaN5MSZO56UjGvcADLi39iNatz7jSucV3\/Gt8yye78Bm7WMakG7nPCUxc1o4bW06DlO9+esuNRpbiQJexRa7DmxzresC127uUDz\/h3S5x1pRMO7LLb2g1YIvfdIp0iQHSDJhrjY6t1+grP6XjyS9p2X2NY\/MLYge\/5uba57Q+ek388Cv6nv0Kz84rIvu\/IPrgc2pnDjmmL6XaLr1lo8qpSic5G9sls2+P4AeviG3\/BHviCRVy7NfM7BBc2Wf00QtCy7s4kp9gufdj3uvbJVt+Q5PkmgYmDKC0jCeZVuNKKuCs4cf0LD\/j4foMQ4tTDC8MMZ24JadOHwtLg1intzgzsifZiSPb7BtYastHMKNdjFECvzoubjY+YmL1OcH5PYpHdsjo2uKdzodcuLVNSXyf7DEBSlG0+oU3sLo5\/gubMZbW7M93hgAAAABJRU5ErkJggg==\" alt=\"\u041a\u043d\u0438\u0433\u0430 \u00ab\u0413\u043b\u0443\u0431\u043e\u043a\u043e\u0435 \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u0435 \u0432 \u043f\u0440\u043e\u0438\u0437\u0432\u043e\u0434\u0441\u0442\u0432\u0435\u00bb\" style=\"position:absolute;top:0;left:0;width:100%;height:100%;object-fit:contain;object-position:center;opacity:1;transition-delay:500ms\"\/><noscript><picture><source srcset=\"https:\/\/theaisummer.com\/static\/502e7c498dd9d981ac44c1dcd10f9276\/69585\/deep-learning-book-cover.png 200w,&#10;https:\/\/theaisummer.com\/static\/502e7c498dd9d981ac44c1dcd10f9276\/497c6\/deep-learning-book-cover.png 400w,&#10;https:\/\/theaisummer.com\/static\/502e7c498dd9d981ac44c1dcd10f9276\/3c17d\/deep-learning-book-cover.png 720w\" sizes=\"(max-width: 720px) 100vw, 720px\"\/><img decoding=\"async\" loading=\"lazy\" sizes=\"(max-width: 720px) 100vw, 720px\" srcset=\"https:\/\/theaisummer.com\/static\/502e7c498dd9d981ac44c1dcd10f9276\/69585\/deep-learning-book-cover.png 200w,&#10;https:\/\/theaisummer.com\/static\/502e7c498dd9d981ac44c1dcd10f9276\/497c6\/deep-learning-book-cover.png 400w,&#10;https:\/\/theaisummer.com\/static\/502e7c498dd9d981ac44c1dcd10f9276\/3c17d\/deep-learning-book-cover.png 720w\" src=\"https:\/\/theaisummer.com\/static\/502e7c498dd9d981ac44c1dcd10f9276\/3c17d\/deep-learning-book-cover.png\" 
alt=\"\u041a\u043d\u0438\u0433\u0430 \u00ab\u0413\u043b\u0443\u0431\u043e\u043a\u043e\u0435 \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u0435 \u0432 \u043f\u0440\u043e\u0438\u0437\u0432\u043e\u0434\u0441\u0442\u0432\u0435\u00bb\" style=\"position:absolute;top:0;left:0;opacity:1;width:100%;height:100%;object-fit:cover;object-position:center\"\/><\/picture><\/noscript><\/div>\n<div class=\"dl-prod-book-inline-banner__text\">\n<h2>\u041a\u043d\u0438\u0433\u0430 \u00ab\u0413\u043b\u0443\u0431\u043e\u043a\u043e\u0435 \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u0435 \u0432 \u043f\u0440\u043e\u0438\u0437\u0432\u043e\u0434\u0441\u0442\u0432\u0435\u00bb \ud83d\udcd6<\/h2>\n<h4>\u0423\u0437\u043d\u0430\u0439\u0442\u0435, \u043a\u0430\u043a \u0441\u043e\u0437\u0434\u0430\u0432\u0430\u0442\u044c, \u043e\u0431\u0443\u0447\u0430\u0442\u044c, \u0440\u0430\u0437\u0432\u0435\u0440\u0442\u044b\u0432\u0430\u0442\u044c, \u043c\u0430\u0441\u0448\u0442\u0430\u0431\u0438\u0440\u043e\u0432\u0430\u0442\u044c \u0438 \u043f\u043e\u0434\u0434\u0435\u0440\u0436\u0438\u0432\u0430\u0442\u044c \u043c\u043e\u0434\u0435\u043b\u0438 \u0433\u043b\u0443\u0431\u043e\u043a\u043e\u0433\u043e \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f.  
\u0418\u0437\u0443\u0447\u0438\u0442\u0435 \u0438\u043d\u0444\u0440\u0430\u0441\u0442\u0440\u0443\u043a\u0442\u0443\u0440\u0443 \u043c\u0430\u0448\u0438\u043d\u043d\u043e\u0433\u043e \u043e\u0431\u0443\u0447\u0435\u043d\u0438\u044f \u0438 MLOps \u043d\u0430 \u043f\u0440\u0430\u043a\u0442\u0438\u0447\u0435\u0441\u043a\u0438\u0445 \u043f\u0440\u0438\u043c\u0435\u0440\u0430\u0445.<\/h4>\n<p>\u0423\u0437\u043d\u0430\u0442\u044c \u0431\u043e\u043b\u044c\u0448\u0435<\/p><\/div>\n<\/div>\n<p><em class=\"affiliate-disclosure\">* \u0420\u0430\u0441\u043a\u0440\u044b\u0442\u0438\u0435 \u0438\u043d\u0444\u043e\u0440\u043c\u0430\u0446\u0438\u0438: \u043e\u0431\u0440\u0430\u0442\u0438\u0442\u0435 \u0432\u043d\u0438\u043c\u0430\u043d\u0438\u0435, \u0447\u0442\u043e \u043d\u0435\u043a\u043e\u0442\u043e\u0440\u044b\u0435 \u0438\u0437 \u043f\u0440\u0438\u0432\u0435\u0434\u0435\u043d\u043d\u044b\u0445 \u0432\u044b\u0448\u0435 \u0441\u0441\u044b\u043b\u043e\u043a \u043c\u043e\u0433\u0443\u0442 \u0431\u044b\u0442\u044c \u043f\u0430\u0440\u0442\u043d\u0435\u0440\u0441\u043a\u0438\u043c\u0438 \u0441\u0441\u044b\u043b\u043a\u0430\u043c\u0438, \u0438 \u043c\u044b \u0431\u0435\u0437 \u0434\u043e\u043f\u043e\u043b\u043d\u0438\u0442\u0435\u043b\u044c\u043d\u044b\u0445 \u0437\u0430\u0442\u0440\u0430\u0442 \u0434\u043b\u044f \u0432\u0430\u0441 \u043f\u043e\u043b\u0443\u0447\u0438\u043c \u043a\u043e\u043c\u0438\u0441\u0441\u0438\u044e, \u0435\u0441\u043b\u0438 \u0432\u044b \u0440\u0435\u0448\u0438\u0442\u0435 \u0441\u043e\u0432\u0435\u0440\u0448\u0438\u0442\u044c \u043f\u043e\u043a\u0443\u043f\u043a\u0443 \u043f\u043e\u0441\u043b\u0435 \u043f\u0435\u0440\u0435\u0445\u043e\u0434\u0430 \u043f\u043e \u0441\u0441\u044b\u043b\u043a\u0435.<\/em><\/p>\n<\/div>\n","protected":false},"excerpt":{"rendered":"<p>\u0412 \u044d\u0442\u043e\u043c \u0440\u0443\u043a\u043e\u0432\u043e\u0434\u0441\u0442\u0432\u0435 \u043c\u044b \u0440\u0430\u0441\u0441\u043c\u043e\u0442\u0440\u0438\u043c, \u043a\u0430\u043a 
\u0440\u0430\u0437\u0440\u0430\u0431\u043e\u0442\u0430\u0442\u044c \u043d\u0435\u0439\u0440\u043e\u043d\u043d\u0443\u044e \u0441\u0435\u0442\u044c (NN) \u0441 \u043f\u043e\u043c\u043e\u0449\u044c\u044e JAX. \u0418 \u043a\u0430\u043a\u0443\u044e \u043b\u0443\u0447\u0448\u0435 \u043c\u043e\u0434\u0435\u043b\u044c \u0432\u044b\u0431\u0440\u0430\u0442\u044c, \u0447\u0435\u043c \u0422\u0440\u0430\u043d\u0441\u0444\u043e\u0440\u043c\u0435\u0440. \u041f\u043e \u043c\u0435\u0440\u0435 \u0440\u043e\u0441\u0442\u0430 \u043f\u043e\u043f\u0443\u043b\u044f\u0440\u043d\u043e\u0441\u0442\u0438 JAX \u0432\u0441\u0435 \u0431\u043e\u043b\u044c\u0448\u0435 \u0438 \u0431\u043e\u043b\u044c\u0448\u0435 \u043a\u043e\u043c\u0430\u043d\u0434 \u0440\u0430\u0437\u0440\u0430\u0431\u043e\u0442\u0447\u0438\u043a\u043e\u0432 \u043d\u0430\u0447\u0438\u043d\u0430\u044e\u0442 \u044d\u043a\u0441\u043f\u0435\u0440\u0438\u043c\u0435\u043d\u0442\u0438\u0440\u043e\u0432\u0430\u0442\u044c \u0441 \u043d\u0438\u043c \u0438 \u0432\u043a\u043b\u044e\u0447\u0430\u0442\u044c \u0435\u0433\u043e \u0432 \u0441\u0432\u043e\u0438 \u043f\u0440\u043e\u0435\u043a\u0442\u044b. 
\u041d\u0435\u0441\u043c\u043e\u0442\u0440\u044f \u043d\u0430 \u0442\u043e, \u0447\u0442\u043e \u0435\u043c\u0443 \u043d\u0435 \u0445\u0432\u0430\u0442\u0430\u0435\u0442 \u0437\u0440\u0435\u043b\u043e\u0441\u0442\u0438 Tensorflow \u0438\u043b\u0438 Pytorch, \u043e\u043d \u043f\u0440\u0435\u0434\u043e\u0441\u0442\u0430\u0432\u043b\u044f\u0435\u0442 \u043d\u0435\u043a\u043e\u0442\u043e\u0440\u044b\u0435 [&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":1280,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[1],"tags":[],"class_list":{"0":"post-1279","1":"post","2":"type-post","3":"status-publish","4":"format-standard","5":"has-post-thumbnail","7":"category-ai-research-and-news"},"_links":{"self":[{"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/posts\/1279","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/gptmain.news\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=1279"}],"version-history":[{"count":0,"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/posts\/1279\/revisions"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/media\/1280"}],"wp:attachment":[{"href":"https:\/\/gptmain.news\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=1279"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/gptmain.news\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=1279"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/gptmain.news\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=1279"}],"curies":[{"name":"wp","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}