{"id":1267,"date":"2023-08-13T04:41:38","date_gmt":"2023-08-13T04:41:38","guid":{"rendered":"https:\/\/www.gptmain.news\/?p=1267"},"modified":"2023-08-13T04:41:38","modified_gmt":"2023-08-13T04:41:38","slug":"jax-%d0%bf%d1%80%d0%be%d1%82%d0%b8%d0%b2-tensorflow-%d0%bf%d1%80%d0%be%d1%82%d0%b8%d0%b2-pytorch-%d1%81%d0%be%d0%b7%d0%b4%d0%b0%d0%bd%d0%b8%d0%b5-%d0%b2%d0%b0%d1%80%d0%b8%d0%b0%d1%86%d0%b8%d0%be","status":"publish","type":"post","link":"https:\/\/gptmain.news\/?p=1267","title":{"rendered":"JAX vs Tensorflow vs Pytorch: Building a Variational Autoencoder (VAE)\n | GPTMain News"},"content":{"rendered":"<div id=\"\">\n<p>I was very curious to see how JAX compares with Pytorch or Tensorflow.  I figured the best way to compare frameworks is to build the same thing from scratch in each of them.  And that is exactly what I did.  
In this article, I develop a variational autoencoder with JAX, Tensorflow, and Pytorch at the same time.  I will present the code for each component side by side in order to find differences, similarities, weaknesses, and strengths.<\/p>\n<p>Shall we begin?<\/p>\n<h2 id=\"prologue\">Prologue<\/h2>\n<p>A few things to note before we dive into the code:<\/p>\n<ul>\n<li>\n<p>I will use Flax on top of JAX. Flax is a neural network library developed by Google.  
It contains many ready-to-use deep learning modules, layers, functions, and operations.<\/p>\n<\/li>\n<li>\n<p>For the Tensorflow implementation, I will rely on the Keras abstractions.<\/p>\n<\/li>\n<li>\n<p>For Pytorch, I will use the standard <code>nn.Module<\/code>.<\/p>\n<\/li>\n<\/ul>\n<p>Since most of us are somewhat familiar with Tensorflow and Pytorch, we will pay more attention to JAX and Flax.  That is why I will explain things along the way that may be unfamiliar to many.  
So you can treat this article as a light tutorial on Flax.<\/p>\n<p>I also assume that you are familiar with the basic principles behind VAEs.  If not, you can consult my previous article on latent variable models.  If everything seems clear, let's continue.<\/p>\n<p><strong>Quick recap<\/strong>: A vanilla autoencoder consists of an encoder and a decoder.  
The encoder maps the input to a latent representation <span class=\"inlineMath\"><span class=\"katex\"><span class=\"katex-mathml\"><math xmlns=\"http:\/\/www.w3.org\/1998\/Math\/MathML\"><semantics><mrow><mi>z<\/mi><\/mrow><annotation encoding=\"application\/x-tex\">z<\/annotation><\/semantics><\/math><\/span><span class=\"katex-html\" aria-hidden=\"true\"><span class=\"base\"><span class=\"strut\" style=\"height:0.43056em;vertical-align:0em\"\/><span class=\"mord mathnormal\" style=\"margin-right:0.04398em\">z<\/span><\/span><\/span><\/span><\/span>, and the decoder tries to reconstruct the input from that representation.  In variational autoencoders, stochasticity is also added to the mix, in the sense that the latent representation provides a probability distribution.  
This is achieved with the reparameterization trick.<\/p>\n<p><span class=\"gatsby-resp-image-wrapper\" style=\"position:relative;display:block;margin-left:auto;margin-right:auto;max-width:1200px\"><\/p>\n<p>    <span class=\"gatsby-resp-image-background-image\" style=\"padding-bottom:16%;position:relative;bottom:0;left:0;background-image:url('data:image\/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAADCAYAAACTWi8uAAAACXBIWXMAAAsTAAALEwEAmpwYAAAA\/ElEQVQI12PIrjGtmDYlprptomcdAwMD49t722M+PNr99tuLg7e+Ptvj8fLuTob\/\/\/+zAzHzz\/\/\/GV7d2br4y\/MDr74+37\/r\/4fDwiA92TXGUX2TQlPaprgnMfT2R\/5f1pv+39VH4T9QUubFze2lH5\/u\/\/\/z3Yn\/P94cngc0iGHf6bt39p99uBTEfnJjy\/kfb44B5Y7+\/\/\/+lBdQD8eUyYm\/lvak\/Xfxk\/\/PUFlpf6Ck2OqonafcMaCk4Ju7O\/xBBn56fvj\/r7dHp\/7\/c4th78nbhfvP3Av99eMNw4vbW5d8eXEYKH\/w\/\/8f512BephaGj0vlhXZvDB1kngEAH1WlTwkQKjnAAAAAElFTkSuQmCC');background-size:cover;display:block\"\/><br \/>\n  <img decoding=\"async\" class=\"gatsby-resp-image-image\" alt=\"vae\" title=\"vae\" src=\"https:\/\/theaisummer.com\/static\/f93cd2d985a4f49829e9e8a0ed5feea5\/c1b63\/vae.png\" srcset=\"\/static\/f93cd2d985a4f49829e9e8a0ed5feea5\/5a46d\/vae.png 300w,\/static\/f93cd2d985a4f49829e9e8a0ed5feea5\/0a47e\/vae.png 600w,\/static\/f93cd2d985a4f49829e9e8a0ed5feea5\/c1b63\/vae.png 1200w,\/static\/f93cd2d985a4f49829e9e8a0ed5feea5\/d61c2\/vae.png 1800w,\/static\/f93cd2d985a4f49829e9e8a0ed5feea5\/97a96\/vae.png 2400w,\/static\/f93cd2d985a4f49829e9e8a0ed5feea5\/eb645\/vae.png 2500w\" sizes=\"(max-width: 1200px) 100vw, 1200px\" style=\"width:100%;height:100%;margin:0;vertical-align:middle;position:absolute;top:0;left:0\" loading=\"lazy\"\/><\/p>\n<p>    <\/span><br \/>\n<em>Image by author<\/em><\/p>\n<h2 
id=\"the-encoder\">The encoder<\/h2>\n<p>For the encoder, a simple linear layer followed by a ReLU activation should be enough for a toy example.  The encoder outputs the two parameters of the latent probability distribution: its mean and, in the code below, the log of its variance (from which the standard deviation follows).<\/p>\n<p>The basic building block of the Flax API is the <code>Module<\/code> abstraction, which is what we will use to implement our encoder in JAX. 
<code>Module<\/code> is part of the <code>linen<\/code> subpackage.  As with Pytorch's <code>nn.Module<\/code>, we again need to define the arguments of our class.  In Pytorch, we are used to declaring them inside the <code>__init__<\/code> function and implementing the forward pass inside the <code>forward<\/code> method.  In Flax, things are a little different.  Arguments are defined either as dataclass attributes or as method arguments.  
Usually, fixed properties are defined as dataclass arguments and dynamic properties as method arguments.  Also, instead of implementing a <code>forward<\/code> method, we implement <code>__call__<\/code>.<\/p>\n<blockquote>\n<p>The Dataclass module was introduced in Python 3.7 as a utility for creating structured classes intended specifically for storing data.  
These classes come with specific properties and functions for working with the data and its representation.  They also cut down on boilerplate code compared to regular classes.<\/p>\n<\/blockquote>\n<p>So, to create a new module in Flax, we need to:<\/p>\n<ul>\n<li>\n<p>Define a class that inherits from <code>flax.linen.Module<\/code><\/p>\n<\/li>\n<li>\n<p>Define the static arguments as dataclass arguments<\/p>\n<\/li>\n<li>\n<p>Implement the forward pass inside the 
<code>__call__<\/code> method.<\/p>\n<\/li>\n<\/ul>\n<p>To bind the arguments to the model and be able to define submodules directly inside the module, we also need to annotate the <code>__call__<\/code> method with <code>@nn.compact<\/code>.<\/p>\n<p>Note that instead of using dataclass arguments and the <code>@nn.compact<\/code> annotation, we could have declared all the arguments inside a <code>setup<\/code> method, just as we would in a Pytorch or Tensorflow 
<code>__init__<\/code>.<\/p>\n<div class=\"tabs\">\n<div class=\"tabs__content\">\n<div class=\"tabs__content__item--active\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> numpy <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> np<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> jax<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">numpy <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> jnp<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">from<\/span><span class=\"token plain\"> jax <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> random<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">from<\/span><span class=\"token plain\"> flax <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> linen <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> nn<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">from<\/span><span 
class=\"token plain\"> flax <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> optim<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">Encoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Module<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> latents<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">int<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">@nn<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">compact<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__call__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> x<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Dense<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">500<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> name<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'fc1'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">relu<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   mean_x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Dense<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token 
plain\">latents<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> name<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'fc2_mean'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   logvar_x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Dense<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">latents<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> name<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'fc2_logvar'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> mean_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar_x<\/span><\/p><\/pre>\n<\/div>\n<div class=\"tabs__content__item\">\n<pre class=\"prism-code 
language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> tensorflow <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> tf<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">from<\/span><span class=\"token plain\"> tensorflow<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">keras <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> layers<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">Encoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Layer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">              latent_dim <\/span><span 
class=\"token operator\">=<\/span><span class=\"token number\">20<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">              name<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'encoder'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">              <\/span><span class=\"token operator\">**<\/span><span class=\"token plain\">kwargs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">super<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">Encoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">name<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">name<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">**<\/span><span class=\"token plain\">kwargs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">.<\/span><span class=\"token plain\">enc1 <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Dense<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">500<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> activation<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'relu'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">mean_x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Dense<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">latent_dim<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">logvar_x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Dense<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">latent_dim<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span 
class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">call<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> inputs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">enc1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">inputs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z_mean <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">mean_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z_log_var <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">logvar_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> z_mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_log_var<\/span><\/p><\/pre>\n<\/div>\n<div class=\"tabs__content__item\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> torch<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">functional <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> F<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">Encoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Module<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token 
plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> latent_dim<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">20<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">super<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">Encoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">enc1 <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token 
plain\">Linear<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">784<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">500<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">mean_x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Linear<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">500<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\">latent_dim<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">logvar_x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Linear<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">500<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> latent_dim<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">forward<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\">inputs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">enc1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">inputs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   x<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> F<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">relu<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z_mean <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">mean_x<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z_log_var <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">logvar_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> z_mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_log_var<\/span><\/p><\/pre>\n<\/div>\n<\/div>\n<\/div>\n<p>A few more things to note before we continue:<\/p>\n<ul>\n<li>\n<p>The <code>flax.linen<\/code> package, conventionally imported as <code>nn<\/code>, contains most deep learning layers and operations, such as <code>Dense<\/code>, <code>relu<\/code>, and
many more.<\/p>\n<\/li>\n<li>\n<p>The code in Flax, Tensorflow, and Pytorch is almost indistinguishable. The visible differences are minor: Keras takes the activation as a <code>Dense<\/code> argument, whereas Pytorch applies <code>F.relu<\/code> explicitly and <code>torch.nn.Linear<\/code> needs the input size up front.<\/p>\n<\/li>\n<\/ul>\n<h2 id=\"the-decoder\">The decoder<\/h2>\n<p>In a very similar fashion, we can develop the decoder in all three frameworks.  The decoder consists of two linear layers that receive the latent representation <span class=\"inlineMath\"><span class=\"katex\"><span class=\"katex-mathml\"><math xmlns=\"http:\/\/www.w3.org\/1998\/Math\/MathML\"><semantics><mrow><mi>z<\/mi><\/mrow><annotation encoding=\"application\/x-tex\">z<\/annotation><\/semantics><\/math><\/span><span class=\"katex-html\" aria-hidden=\"true\"><span class=\"base\"><span class=\"strut\" style=\"height:0.43056em;vertical-align:0em\"\/><span class=\"mord mathnormal\" style=\"margin-right:0.04398em\">z<\/span><\/span><\/span><\/span><\/span>  and output the reconstructed
input.<\/p>\n<p>Again, the implementations are very similar.<\/p>\n<div class=\"tabs\">\n<div class=\"tabs__content\">\n<div class=\"tabs__content__item--active\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">Decoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Module<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">@nn<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">compact<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__call__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Dense<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">500<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> name<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'fc1'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">relu<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Dense<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">784<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> name<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" 
style=\"color:rgb(255, 121, 198)\">'fc2'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> z<\/span><\/p><\/pre>\n<\/div>\n<div class=\"tabs__content__item\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">Decoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Layer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">               name<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'decoder'<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">               <\/span><span class=\"token operator\">**<\/span><span class=\"token plain\">kwargs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">super<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">Decoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">name<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">name<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">**<\/span><span class=\"token plain\">kwargs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">dec1 <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Dense<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">500<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> activation<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'relu'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">out <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> layers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Dense<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">784<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">call<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">dec1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">out<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<\/div>\n<div class=\"tabs__content__item\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">Decoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Module<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> latent_dim<\/span><span class=\"token operator\">=<\/span><span 
class=\"token number\">20<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">super<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">Decoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">dec1 <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Linear<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">latent_dim<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">500<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token 
plain\">out <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Linear<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">500<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">784<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">forward<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">dec1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> F<\/span><span 
class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">relu<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">out<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<\/div>\n<\/div>\n<\/div>\n<h2 id=\"variational-autoencoder\">Variational autoencoder<\/h2>\n<p>To combine the encoder and the decoder, let's create one more class, called <code>VAE<\/code>, that represents the whole architecture.  
Here we also need to write the code for the reparameterization trick.  Overall: the encoder's mean and log-variance are turned into a latent sample by the reparameterization trick, and that sample is fed to the decoder, which produces the reconstructed input.<\/p>\n<p>As a reminder, here is an intuitive picture explaining the reparameterization trick:<\/p>\n<p><span class=\"gatsby-resp-image-wrapper\" style=\"position:relative;display:block;margin-left:auto;margin-right:auto;max-width:700px\"><\/p>\n<p>    <span class=\"gatsby-resp-image-background-image\" 
style=\"padding-bottom:43.66666666666667%;position:relative;bottom:0;left:0;background-image:url('data:image\/png;base64,iVBORw0KGgoAAAANSUhEUgAAABQAAAAJCAIAAAC9o5sfAAAACXBIWXMAAAsTAAALEwEAmpwYAAABgUlEQVQoz11RS0oDQRCdW7gRXOgJ3HkDd3oFDyIIWQlRBJdmm4BoliIqfhMlRMQYTEJAQjL\/iXEm02amp\/92d6Koj6bpetSrelVtiF\/g+iZMlNsTJyaK0RSDmSCk3EmaPvohJYzbQdbQFOPqSLwndGHPPHoF8k25SiXeMDHd5YK\/XY0UyWZqo9SIzruxLPfNKHQ\/MPoVSj0PQwDJJyR\/OlNKCaUwhTLACIEUpkNH5NZ541I7ZuoihHpB\/iG+eEsEzQAAGGMlFjclXj\/R9RkjmGI0dh28tSaaV1OxagPGcJKuFoP83Ui6HqdItlFidHqAKsdMixHlLx6KElLz8ITMvEEneOwM45Q23bQf0ixwhdnmeknGhNKEUIRV7lkPzu+adRsu7VvFBpiKWz6c2xlc95KVgp174uJwE28sgiiM4tjQHzRbAUCsamZywzUXh5BNxdJOxUKEi+cAO3IzocVb91wNxAzxH9y2rMDqB66L9GAS8Sgw+\/2RZ\/u2NXA8RGZ1vwCckvQI\/Hs8VAAAAABJRU5ErkJggg==');background-size:cover;display:block\"\/><br \/>\n  <img decoding=\"async\" class=\"gatsby-resp-image-image\" alt=\"reparameterization-trick\" title=\"reparameterization-trick\" src=\"https:\/\/theaisummer.com\/static\/0bda23009edd2eab6844869ec2700aab\/8c557\/reparameterization-trick.png\" srcset=\"\/static\/0bda23009edd2eab6844869ec2700aab\/5a46d\/reparameterization-trick.png 300w,\/static\/0bda23009edd2eab6844869ec2700aab\/0a47e\/reparameterization-trick.png 600w,\/static\/0bda23009edd2eab6844869ec2700aab\/8c557\/reparameterization-trick.png 700w\" sizes=\"(max-width: 700px) 100vw, 700px\" style=\"width:100%;height:100%;margin:0;vertical-align:middle;position:absolute;top:0;left:0\" loading=\"lazy\"\/><\/p>\n<p>    <\/span><br \/>\n<em>Source: Alexander Amini and Ava Soleimany, Deep
Generative Modeling | MIT 6.S191, http:\/\/introtodeeplearning.com\/<\/em><\/p>\n<p>Note that this time, in the JAX version, we use Flax's <code>setup<\/code> method instead of the <code>nn.compact<\/code> decorator. Also, notice how similar the reparameterization functions are.
Of course, each framework uses its own functions and operations, but the overall picture is almost identical.<\/p>\n<div class=\"tabs\">\n<div class=\"tabs__content\">\n<div class=\"tabs__content__item--active\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">VAE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Module<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> latents<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">int<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">20<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" 
style=\"color:rgb(80, 250, 123)\">setup<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">encoder <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> Encoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">latents<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">decoder <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> Decoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__call__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> 
z_rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">encoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> reparameterize<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z_rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   recon_x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">decoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 
249);font-style:italic\">return<\/span><span class=\"token plain\"> recon_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">reparameterize<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> std <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">exp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">0.5<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> eps <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> random<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">normal<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">shape<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> mean <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> eps <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> std<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> VAE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">latents<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">LATENTS<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<\/div>\n<div class=\"tabs__content__item\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 
249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">VAE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">keras<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">              latent_dim<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">20<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">              name<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'vae'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">              <\/span><span class=\"token operator\">**<\/span><span class=\"token plain\">kwargs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span 
class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">super<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">VAE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">name<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">name<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">**<\/span><span class=\"token plain\">kwargs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">encoder <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> Encoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">latent_dim<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">latent_dim<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">decoder <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> Decoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">call<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> inputs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z_mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_log_var <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">encoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">inputs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   z <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">reparameterize<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z_mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_log_var<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span 
class=\"token plain\"\/><\/p><p><span class=\"token plain\">   reconstructed <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">decoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> reconstructed<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_log_var<\/span><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">reparameterize<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">     eps <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token 
plain\">random<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">normal<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">shape<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">shape<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">     <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> mean <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> eps <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">exp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">logvar <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">.5<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<\/div>\n<div class=\"tabs__content__item\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">class<\/span><span class=\"token plain\"> <\/span><span class=\"token class-name\">VAE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span 
class=\"token plain\">Module<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> latent_dim<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">20<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">       <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">super<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">VAE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">__init__<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">       self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">encoder <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> 
Encoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">latent_dim<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">       self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">decoder <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> Decoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">latent_dim<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">forward<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\">inputs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">     z_mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_log_var <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">encoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">inputs<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">     z <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">reparameterize<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z_mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_log_var<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">     reconstructed <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> self<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">decoder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">     <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> reconstructed<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_log_var<\/span><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">reparameterize<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">self<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> mu<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> log_var<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">     std <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">exp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">0.5<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> log_var<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">     eps <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">randn_like<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">std<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">     <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> mu <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">eps <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> std<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<\/div>\n<\/div>\n<\/div>\n<h2 id=\"loss-and-training-step\">Loss and training step<\/h2>\n<p>Things start to differ once we implement the training step and the loss function, but not by much.<\/p>\n<ol>\n<li>\n<p>To take full advantage of JAX, we need to add automatic vectorization and XLA compilation to our code.  
This is easily done with the <code>vmap<\/code> and <code>jit<\/code> decorators.<\/p>\n<\/li>\n<li>\n<p>Moreover, we need to enable automatic differentiation, which is done here through <code>grad_fn<\/code>, the function returned by the <code>jax.value_and_grad<\/code> transformation.<\/p>\n<\/li>\n<li>\n<p>We use the <code>flax.optim<\/code> package for the optimization algorithms. Recent Flax releases have dropped this package in favor of Optax, so the code in this section assumes an older Flax version.<\/p>\n<\/li>\n<\/ol>\n<p>Another small difference to be aware of is how we pass data into our model.  
This is done through the <code>apply<\/code> method, as in <code>model().apply({'params': params}, batch, z_rng)<\/code>, where <code>batch<\/code> is our training data.<\/p>\n<div class=\"tabs\">\n<div class=\"tabs__content\">\n<div class=\"tabs__content__item--active\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">@jax<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">vmap<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">kl_divergence<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token number\">0.5<\/span><span class=\"token 
plain\"> <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">sum<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">1<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> logvar <\/span><span class=\"token operator\">-<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">square<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">exp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"\/><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">@jax<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">vmap<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" 
style=\"color:rgb(80, 250, 123)\">binary_cross_entropy_with_logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> labels<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> logits <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">log_sigmoid<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token plain\">jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">sum<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">labels <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> logits <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token plain\"> labels<\/span><span 
class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">log<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token plain\">jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">expm1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"\/><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">@jax<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">jit<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">train_step<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> batch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_rng<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">loss_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   recon_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">apply<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">{<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'params'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"> params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">}<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> batch<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   bce_loss <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> binary_cross_entropy_with_logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">recon_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> batch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   kld_loss <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> kl_divergence<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   loss <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> bce_loss <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> kld_loss<\/span><\/p><p><span class=\"token 
plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> loss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> recon_x<\/span><\/p><p><span class=\"token plain\"> grad_fn <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">value_and_grad<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">loss_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> has_aux<\/span><span class=\"token operator\">=<\/span><span class=\"token boolean\">True<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> _<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> grad <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> grad_fn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">target<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> optimizer <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">apply_gradient<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">grad<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> optimizer<\/span><\/p><\/pre>\n<\/div>\n<div class=\"tabs__content__item\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">kl_divergence<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token number\">0.5<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">reduce_sum<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">           <\/span><span class=\"token number\">1<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> logvar <\/span><span class=\"token operator\">-<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">square<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">           tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">exp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> axis<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">binary_cross_entropy_with_logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> labels<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> logits <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token 
plain\">math<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">log_sigmoid<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">reduce_sum<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">           labels <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> logits <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">           <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">1<\/span><span class=\"token operator\">-<\/span><span class=\"token plain\">labels<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">math<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">log<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span 
class=\"token plain\">math<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">expm1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">           axis<\/span><span class=\"token operator\">=<\/span><span class=\"token number\">1<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">       <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"\/><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">@tf<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token decorator annotation punctuation\" style=\"color:rgb(248, 248, 242)\">function<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">train_step<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">with<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">GradientTape<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> tape<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   recon_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   bce_loss <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">reduce_mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">binary_cross_entropy_with_logits<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">recon_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span 
class=\"token plain\"> x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   kld_loss <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">reduce_mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">kl_divergence<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">mean<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   loss <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> bce_loss <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> kld_loss<\/span><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">print<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">loss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> kld_loss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> bce_loss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> gradients <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> 
tape<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">gradient<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">loss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">trainable_variables<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">apply_gradients<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">zip<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">gradients<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">trainable_variables<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<\/div>\n<div class=\"tabs__content__item\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">final_loss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">reconstruction<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> train_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> mu<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   BCE <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">nn<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">BCEWithLogitsLoss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">reduction<\/span><span class=\"token operator\">=<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'sum'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">reconstruction<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> train_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   KLD <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token number\">0.5<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">*<\/span><span class=\"token plain\"> torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span 
class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">sum<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">1<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> logvar <\/span><span class=\"token operator\">-<\/span><span class=\"token plain\"> mu<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">pow<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">2<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">-<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">exp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> BCE <\/span><span class=\"token operator\">+<\/span><span class=\"token plain\"> KLD<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">train_step<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">train_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> train_x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> torch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">from_numpy<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">train_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">zero_grad<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> reconstruction<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> mu<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">train_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> loss <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> final_loss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">reconstruction<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> train_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">,<\/span><span class=\"token plain\">  mu<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> logvar<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> running_loss <\/span><span class=\"token operator\">+=<\/span><span class=\"token plain\"> loss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">item<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> loss<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">backward<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">step<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<\/div>\n<\/div>\n<\/div>\n<p>Recall that VAEs are trained by maximizing the evidence lower bound (ELBO). The loss above is its negative: the cross-entropy term is the reconstruction error and <code>KLD<\/code> is the KL divergence, so minimizing the loss maximizes the ELBO.<\/p>\n<div class=\"math\"><span class=\"katex-display\"><span 
class=\"katex\"><span class=\"katex-mathml\"><math xmlns=\"http:\/\/www.w3.org\/1998\/Math\/MathML\" display=\"block\"><semantics><mrow><msub><mi>L<\/mi><mrow><mi>\u03b8<\/mi><mo separator=\"true\">,<\/mo><mi>\u03d5<\/mi><\/mrow><\/msub><mo stretchy=\"false\">(<\/mo><mi>x<\/mi><mo stretchy=\"false\">)<\/mo><mo>=<\/mo><msub><mtext mathvariant=\"bold\">E<\/mtext><mrow><msub><mi>q<\/mi><mi>\u03d5<\/mi><\/msub><mo stretchy=\"false\">(<\/mo><mi>z<\/mi><mi mathvariant=\"normal\">\u2223<\/mi><mi>x<\/mi><mo stretchy=\"false\">)<\/mo><\/mrow><\/msub><mo stretchy=\"false\">[<\/mo><mi>l<\/mi><mi>o<\/mi><mi>g<\/mi><msub><mi>p<\/mi><mi>\u03b8<\/mi><\/msub><mo stretchy=\"false\">(<\/mo><mi>x<\/mi><mi mathvariant=\"normal\">\u2223<\/mi><mi>z<\/mi><mo stretchy=\"false\">)<\/mo><mo stretchy=\"false\">]<\/mo><mo>\u2212<\/mo><mtext mathvariant=\"bold\">KL<\/mtext><mo stretchy=\"false\">(<\/mo><msub><mi>q<\/mi><mi>\u03d5<\/mi><\/msub><mo stretchy=\"false\">(<\/mo><mi>z<\/mi><mi mathvariant=\"normal\">\u2223<\/mi><mi>x<\/mi><mo stretchy=\"false\">)<\/mo><mi mathvariant=\"normal\">\u2223<\/mi><mi mathvariant=\"normal\">\u2223<\/mi><msub><mi>p<\/mi><mi>\u03b8<\/mi><\/msub><mo stretchy=\"false\">(<\/mo><mi>z<\/mi><mo stretchy=\"false\">)<\/mo><mo stretchy=\"false\">)<\/mo><\/mrow><annotation encoding=\"application\/x-tex\">L_{\\theta,\\phi}(x) = \\textbf{E}_{q_{\\phi}(z|x)} [ log p_{\\theta}(x|z) ] - \\textbf{KL}(q_{\\phi}(z|x) || p_{\\theta}(z))<\/annotation><\/semantics><\/math><\/span><span class=\"katex-html\" aria-hidden=\"true\"><span class=\"base\"><span class=\"strut\" style=\"height:1.036108em;vertical-align:-0.286108em\"\/><span class=\"mord\"><span class=\"mord mathnormal\">L<\/span><span class=\"msupsub\"><span class=\"vlist-t vlist-t2\"><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.3361079999999999em\"><span 
style=\"top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em\"><span class=\"pstrut\" style=\"height:2.7em\"\/><span class=\"sizing reset-size6 size3 mtight\"><span class=\"mord mtight\"><span class=\"mord mathnormal mtight\" style=\"margin-right:0.02778em\">\u03b8<\/span><span class=\"mpunct mtight\">,<\/span><span class=\"mord mathnormal mtight\">\u03d5<\/span><\/span><\/span><\/span><\/span><span class=\"vlist-s\">\u200b<\/span><\/span><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.286108em\"><span\/><\/span><\/span><\/span><\/span><\/span><span class=\"mopen\">(<\/span><span class=\"mord mathnormal\">x<\/span><span class=\"mclose\">)<\/span><span class=\"mspace\" style=\"margin-right:0.2777777777777778em\"\/><span class=\"mrel\">=<\/span><span class=\"mspace\" style=\"margin-right:0.2777777777777778em\"\/><\/span><span class=\"base\"><span class=\"strut\" style=\"height:1.1332799999999998em;vertical-align:-0.38327999999999984em\"\/><span class=\"mord\"><span class=\"mord text\"><span class=\"mord textbf\">E<\/span><\/span><span class=\"msupsub\"><span class=\"vlist-t vlist-t2\"><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.3448em\"><span style=\"top:-2.5198em;margin-right:0.05em\"><span class=\"pstrut\" style=\"height:2.7em\"\/><span class=\"sizing reset-size6 size3 mtight\"><span class=\"mord mtight\"><span class=\"mord mtight\"><span class=\"mord mathnormal mtight\" style=\"margin-right:0.03588em\">q<\/span><span class=\"msupsub\"><span class=\"vlist-t vlist-t2\"><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.3448em\"><span style=\"top:-2.3487714285714287em;margin-left:-0.03588em;margin-right:0.07142857142857144em\"><span class=\"pstrut\" style=\"height:2.5em\"\/><span class=\"sizing reset-size3 size1 mtight\"><span class=\"mord mtight\"><span class=\"mord mathnormal mtight\">\u03d5<\/span><\/span><\/span><\/span><\/span><span 
class=\"vlist-s\">\u200b<\/span><\/span><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.29011428571428566em\"><span\/><\/span><\/span><\/span><\/span><\/span><span class=\"mopen mtight\">(<\/span><span class=\"mord mathnormal mtight\" style=\"margin-right:0.04398em\">z<\/span><span class=\"mord mtight\">\u2223<\/span><span class=\"mord mathnormal mtight\">x<\/span><span class=\"mclose mtight\">)<\/span><\/span><\/span><\/span><\/span><span class=\"vlist-s\">\u200b<\/span><\/span><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.38327999999999984em\"><span\/><\/span><\/span><\/span><\/span><\/span><span class=\"mopen\">[<\/span><span class=\"mord mathnormal\" style=\"margin-right:0.01968em\">l<\/span><span class=\"mord mathnormal\">o<\/span><span class=\"mord mathnormal\" style=\"margin-right:0.03588em\">g<\/span><span class=\"mord\"><span class=\"mord mathnormal\">p<\/span><span class=\"msupsub\"><span class=\"vlist-t vlist-t2\"><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.33610799999999996em\"><span style=\"top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em\"><span class=\"pstrut\" style=\"height:2.7em\"\/><span class=\"sizing reset-size6 size3 mtight\"><span class=\"mord mtight\"><span class=\"mord mathnormal mtight\" style=\"margin-right:0.02778em\">\u03b8<\/span><\/span><\/span><\/span><\/span><span class=\"vlist-s\">\u200b<\/span><\/span><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.15em\"><span\/><\/span><\/span><\/span><\/span><\/span><span class=\"mopen\">(<\/span><span class=\"mord mathnormal\">x<\/span><span class=\"mord\">\u2223<\/span><span class=\"mord mathnormal\" style=\"margin-right:0.04398em\">z<\/span><span class=\"mclose\">)<\/span><span class=\"mclose\">]<\/span><span class=\"mspace\" style=\"margin-right:0.2222222222222222em\"\/><span class=\"mbin\">\u2212<\/span><span class=\"mspace\" style=\"margin-right:0.2222222222222222em\"\/><\/span><span 
class=\"base\"><span class=\"strut\" style=\"height:1.036108em;vertical-align:-0.286108em\"\/><span class=\"mord text\"><span class=\"mord textbf\">KL<\/span><\/span><span class=\"mopen\">(<\/span><span class=\"mord\"><span class=\"mord mathnormal\" style=\"margin-right:0.03588em\">q<\/span><span class=\"msupsub\"><span class=\"vlist-t vlist-t2\"><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.3361079999999999em\"><span style=\"top:-2.5500000000000003em;margin-left:-0.03588em;margin-right:0.05em\"><span class=\"pstrut\" style=\"height:2.7em\"\/><span class=\"sizing reset-size6 size3 mtight\"><span class=\"mord mtight\"><span class=\"mord mathnormal mtight\">\u03d5<\/span><\/span><\/span><\/span><\/span><span class=\"vlist-s\">\u200b<\/span><\/span><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.286108em\"><span\/><\/span><\/span><\/span><\/span><\/span><span class=\"mopen\">(<\/span><span class=\"mord mathnormal\" style=\"margin-right:0.04398em\">z<\/span><span class=\"mord\">\u2223<\/span><span class=\"mord mathnormal\">x<\/span><span class=\"mclose\">)<\/span><span class=\"mord\">\u2223<\/span><span class=\"mord\">\u2223<\/span><span class=\"mord\"><span class=\"mord mathnormal\">p<\/span><span class=\"msupsub\"><span class=\"vlist-t vlist-t2\"><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.33610799999999996em\"><span style=\"top:-2.5500000000000003em;margin-left:0em;margin-right:0.05em\"><span class=\"pstrut\" style=\"height:2.7em\"\/><span class=\"sizing reset-size6 size3 mtight\"><span class=\"mord mtight\"><span class=\"mord mathnormal mtight\" style=\"margin-right:0.02778em\">\u03b8<\/span><\/span><\/span><\/span><\/span><span class=\"vlist-s\">\u200b<\/span><\/span><span class=\"vlist-r\"><span class=\"vlist\" style=\"height:0.15em\"><span\/><\/span><\/span><\/span><\/span><\/span><span class=\"mopen\">(<\/span><span class=\"mord mathnormal\" 
style=\"margin-right:0.04398em\">z<\/span><span class=\"mclose\">)<\/span><span class=\"mclose\">)<\/span><\/span><\/span><\/span><\/span><\/div>\n<h2 id=\"training-loop\">Training loop<\/h2>\n<p>Finally, it's time for the entire training loop, which executes the <code>train_step<\/code> function iteratively.<\/p>\n<p>In Flax, the model must be initialized before training, which is done with the <code>init<\/code> function, for example: <code>params = model().init(key, init_data, rng)['params']<\/code>.  
The optimizer needs a similar initialization: <code>optimizer = optim.Adam(learning_rate=LEARNING_RATE).create(params)<\/code>. This is the legacy <code>flax.optim<\/code> API, which newer Flax releases replace with Optax.<\/p>\n<p><code>jax.device_put<\/code> transfers the optimizer to the memory of the default device, which is the GPU when one is available.<\/p>\n<div class=\"tabs\">\n<div class=\"tabs__content\">\n<div class=\"tabs__content__item--active\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token plain\">rng <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> random<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">PRNGKey<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> key <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> random<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">split<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">rng<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">init_data <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">ones<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">BATCH_SIZE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">784<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> jnp<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">float32<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">params <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">init<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">key<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> init_data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'params'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">optimizer <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> optim<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Adam<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">learning_rate<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">LEARNING_RATE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">create<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">params<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">optimizer <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> jax<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">device_put<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> z_key<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> eval_rng <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> random<\/span><span 
class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">split<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">3<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">z <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> random<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">normal<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">z_key<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">64<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> LATENTS<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">steps_per_epoch <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">50000<\/span><span class=\"token plain\"> <\/span><span class=\"token operator\">\/\/<\/span><span class=\"token plain\"> BATCH_SIZE<\/span><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">for<\/span><span class=\"token plain\"> epoch <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">in<\/span><span class=\"token 
plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">range<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">NUM_EPOCHS<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">for<\/span><span class=\"token plain\"> _ <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">in<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">range<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">steps_per_epoch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   batch <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">next<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">train_ds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> key <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> random<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">split<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span 
class=\"token plain\">rng<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   optimizer <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> train_step<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> batch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> key<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<\/div>\n<div class=\"tabs__content__item\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token plain\">vae <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> VAE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">latent_dim<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">LATENTS<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">optimizer <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">keras<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">optimizers<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Adam<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">1e<\/span><span class=\"token operator\">-<\/span><span class=\"token 
number\">4<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">for<\/span><span class=\"token plain\"> epoch <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">in<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">range<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">NUM_EPOCHS<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">for<\/span><span class=\"token plain\"> train_x <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">in<\/span><span class=\"token plain\"> train_ds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   train_step<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">vae<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> train_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> optimizer<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<\/div>\n<div class=\"tabs__content__item\">\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 
249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">train<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\">training_data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   optimizer <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> optim<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Adam<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">model<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">parameters<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> lr<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">LEARNING_RATE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   running_loss <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">0.0<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">   <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">for<\/span><span class=\"token plain\"> epoch <\/span><span class=\"token keyword\" style=\"color:rgb(189, 
147, 249);font-style:italic\">in<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">range<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">NUM_EPOCHS<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">       <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">for<\/span><span class=\"token plain\"> i<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> train_x <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">in<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">enumerate<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">training_data<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">          train_step<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">train_x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">vae <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> VAE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">LATENTS<\/span><span 
class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">train<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">vae<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> train_ds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<\/div>\n<\/div>\n<\/div>\n<h2 id=\"load-and-process-data\">Load and process data<\/h2>\n<p>One thing I haven't mentioned is data. How do we load and preprocess data in Flax? Well, Flax doesn't include data manipulation packages yet, apart from the basic operations of <code>jax.numpy<\/code>. 
For now, the best option is to borrow packages from other frameworks, such as Tensorflow Datasets (tfds) or Torchvision. To keep the article self-contained, I will include the code I used to load a sample training dataset with <code>tfds<\/code>. 
Feel free to use your own data loader if you plan to run the implementations presented in this article.<\/p>\n<pre class=\"prism-code language-python\" style=\"color:#F8F8F2;background-color:#282A36\"><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> numpy <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> np<\/span><\/p><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> tensorflow <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> tf<\/span><\/p><p><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">import<\/span><span class=\"token plain\"> tensorflow_datasets <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">as<\/span><span class=\"token plain\"> tfds<\/span><\/p><p><span class=\"token plain\">tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">config<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">experimental<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">set_visible_devices<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'GPU'<\/span><span class=\"token punctuation\" 
style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"\/><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">def<\/span><span class=\"token plain\"> <\/span><span class=\"token function\" style=\"color:rgb(80, 250, 123)\">prepare_image<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">:<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">cast<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'image'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">float32<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> x <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> tf<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">reshape<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">x<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">,<\/span><span class=\"token plain\"> <\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token operator\">-<\/span><span class=\"token number\">1<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">,<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\"> <\/span><span class=\"token keyword\" style=\"color:rgb(189, 147, 249);font-style:italic\">return<\/span><span class=\"token plain\"> x<\/span><\/p><p><span class=\"token plain\">ds_builder <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> tfds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">builder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token string\" style=\"color:rgb(255, 121, 198)\">'binarized_mnist'<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">ds_builder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">download_and_prepare<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">train_ds <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> ds_builder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">as_dataset<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">split<\/span><span class=\"token operator\">=<\/span><span 
class=\"token plain\">tfds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Split<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">TRAIN<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">train_ds <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> train_ds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">map<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">prepare_image<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">train_ds <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> train_ds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">cache<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">train_ds <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> train_ds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">repeat<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">train_ds <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> train_ds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 
242)\">.<\/span><span class=\"token plain\">shuffle<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">50000<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">train_ds <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> train_ds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">batch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">BATCH_SIZE<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">train_ds <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> <\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">iter<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">tfds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">as_numpy<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">train_ds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">test_ds <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> ds_builder<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">as_dataset<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">split<\/span><span class=\"token operator\">=<\/span><span class=\"token plain\">tfds<\/span><span 
class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">Split<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">TEST<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">test_ds <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> test_ds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">map<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">prepare_image<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">batch<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token number\">10000<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token plain\"\/><\/p><p><span class=\"token plain\">test_ds <\/span><span class=\"token operator\">=<\/span><span class=\"token plain\"> np<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">.<\/span><span class=\"token plain\">array<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token builtin\" style=\"color:rgb(189, 147, 249)\">list<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">(<\/span><span class=\"token plain\">test_ds<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">[<\/span><span class=\"token number\">0<\/span><span class=\"token punctuation\" style=\"color:rgb(248, 248, 242)\">]<\/span><span class=\"token 
punctuation\" style=\"color:rgb(248, 248, 242)\">)<\/span><\/p><\/pre>\n<h2 id=\"final-observations\">Final observations<\/h2>\n<p>To close the article, let's discuss a few final observations that emerge from a close look at the code:<\/p>\n<ul>\n<li>\n<p>All 3 frameworks have reduced boilerplate code to a minimum, although Flax requires a bit more, especially in the training part. 
However, this is only to ensure that we take advantage of all the available transformations, such as automatic differentiation, vectorization, and the just-in-time compiler.<\/p>\n<\/li>\n<li>\n<p>The definition of modules, layers, and models is almost identical in all of them.<\/p>\n<\/li>\n<li>\n<p>Flax and JAX are by design quite flexible and expandable.<\/p>\n<\/li>\n<li>\n<p>Flax does not have data loading and processing capabilities yet.<\/p>\n<\/li>\n<li>\n<p>In terms of ready-to-use layers and optimizers, Flax doesn't need to be jealous of Tensorflow and Pytorch. Admittedly, it lacks the giant library of its competitors, but it is gradually getting there.<\/p>\n<\/li>\n<\/ul>\n<\/div>\n","protected":false},"excerpt":{"rendered":"<p>I was really curious to see how JAX compares to Pytorch or Tensorflow. I figured that the best way to compare frameworks is to build the same thing from scratch in each of them. And this is exactly what I did. 
In this article, I develop a Variational Autoencoder with JAX, Tensorflow and Pytorch at the same time. I will present [&hellip;]<\/p>\n","protected":false},"author":1,"featured_media":1268,"comment_status":"open","ping_status":"open","sticky":false,"template":"","format":"standard","meta":{"footnotes":""},"categories":[1],"tags":[],"class_list":{"0":"post-1267","1":"post","2":"type-post","3":"status-publish","4":"format-standard","5":"has-post-thumbnail","7":"category-ai-research-and-news"},"_links":{"self":[{"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/posts\/1267","targetHints":{"allow":["GET"]}}],"collection":[{"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/posts"}],"about":[{"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/types\/post"}],"author":[{"embeddable":true,"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/users\/1"}],"replies":[{"embeddable":true,"href":"https:\/\/gptmain.news\/index.php?rest_route=%2Fwp%2Fv2%2Fcomments&post=1267"}],"version-history":[{"count":0,"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/posts\/1267\/revisions"}],"wp:featuredmedia":[{"embeddable":true,"href":"https:\/\/gptmain.news\/index.php?rest_route=\/wp\/v2\/media\/1268"}],"wp:attachment":[{"href":"https:\/\/gptmain.news\/index.php?rest_route=%2Fwp%2Fv2%2Fmedia&parent=1267"}],"wp:term":[{"taxonomy":"category","embeddable":true,"href":"https:\/\/gptmain.news\/index.php?rest_route=%2Fwp%2Fv2%2Fcategories&post=1267"},{"taxonomy":"post_tag","embeddable":true,"href":"https:\/\/gptmain.news\/index.php?rest_route=%2Fwp%2Fv2%2Ftags&post=1267"}],"curies":[{"name":"w
p","href":"https:\/\/api.w.org\/{rel}","templated":true}]}}