مقدمة عن JAX لـ التعلم الآلي و NumPy
JAX عبارة عن مكتبة Python مصممة للحوسبة الرقمية عالية الأداء ،
وخاصة أبحاث التعلم الآلي. يعمل على تسريع كود Python و NumPy باستخدام GPU.
لقد ظهرت JAX في مجال التعلم الآلي بطموحات لجعل التعلم الآلي بسيطًا وفعالًا.
لا يزال JAX مشروعًا بحثيًا من Google و Deepmind وليس بعد
منتجًا رسميًا من Google ولكن تم استخدامه على نطاق واسع داخليًا واعتمده باحثو
ML الخارجيون. أردنا تقديم مقدمة حول JAX وكيفية تثبيت JAX ومزاياه وإمكانياته.
ما هو JAX للتعلم الآلي ؟
JAX عبارة عن مكتبة Python مصممة للحوسبة الرقمية عالية الأداء ،
وخاصة أبحاث التعلم الآلي. تعتمد واجهة برمجة التطبيقات الخاصة
بها للوظائف العددية على NumPy ، وهي مجموعة من الوظائف المستخدمة
في الحوسبة العلمية. يركز JAX على تسريع عملية التعلم الآلي باستخدام XLA
لتجميع وظائف NumPy على وحدات معالجة الرسومات ويستخدم autograd
للتمييز بين وظائف Python و NumPy بالإضافة إلى التحسين المستند إلى التدرج.
JAX قادر على التمييز من خلال الحلقات ، والفروع ، والعودية ، والإغلاق ،
وأخذ مشتقات من مشتقات بسهولة باستخدام تسريع GPU. يدعم JAX أيضًا
backpropagation والتمايز في الوضع الأمامي.
تقدم JAX أداءً فائقًا عند استخدام وحدات معالجة الرسومات لتشغيل
التعليمات البرمجية الخاصة بك وخيار تجميع في الوقت المناسب (JIT) لتسريع
المشاريع الكبيرة بسهولة ، والتي سنتعمق فيها لاحقًا في هذه المقالة.
فكر في JAX كمكتبة Python تقوم بتعديل كود NumPy و Python
بتحويلات وظيفية لتمكين التعلم الآلي المتسارع. كقاعدة عامة ، يجب عليك
استخدام JAX عندما تخطط للتدريب باستخدام وحدات معالجة الرسومات أو
حساب التدرجات اللونية (autograd) أو استخدام تجميع كود JIT.
لماذا أستخدم JAX ؟
بالإضافة إلى العمل مع وحدات المعالجة المركزية العادية ، فإن الوظيفة
الرئيسية لـ JAX هي القدرة على العمل بكامل طاقته مع وحدات معالجة
مختلفة مثل وحدات معالجة الرسومات. يمنح هذا JAX ميزة كبيرة على الحزم
المماثلة لأن استخدام موازاة GPU يتيح أداءً أسرع من وحدات المعالجة المركزية
عندما يتعلق الأمر بمعالجة الصور والمتجهات.
هذا مهم للغاية لأنه عند استخدام مكتبة NumPy ، يمكن للمستخدمين إنشاء مصفوفات
بأحجام استثنائية تسمح لوحدات معالجة الرسومات أن تكون أكثر كفاءة في
الوقت عند معالجة تنسيقات البيانات هذه.
100 ضعف السرعة والأداء من خلال اثنين من التطبيقات الرئيسية:
1- Vectorization -
معالجة بيانات متعددة كتعليمات فردية توفر تسريعًا كبيرًا لحسابات الجبر الخطي والتعلم الآلي
2- موازاة الكود -
عملية أخذ الكود التسلسلي الذي يعمل على معالج واحد وتوزيعه. تُفضل وحدات معالجة
الرسومات هنا نظرًا لوجود العديد من المعالجات المتخصصة في العمليات الحسابية.
3- التفاضل التلقائي -
تفاضل بسيط للغاية ومباشر يمكن ربطه عدة مرات لتقييم المشتقات ذات الترتيب الأعلى بسهولة.
كيفية تثبيت JAX ؟
لتثبيت إصدار JAX الخاص بوحدة المعالجة المركزية فقط ، والذي قد يكون
مفيدًا لإجراء التطوير المحلي على جهاز كمبيوتر محمول ، يمكنك تشغيله :
Shell
pip install --upgrade pip
pip install --upgrade "jax[cpu]"
في نظام Linux ، غالبًا ما يكون من الضروري تحديث pip
أولاً إلى إصدار يدعم العديد من عجلاتlinux2014.
1- تركيب النقطة: GPU (CUDA)
لتثبيت JAX مع دعم كل من CPU و NVIDIA GPU ، يجب عليك أولاً
أنظمة التعلم العميق الشائعة الأخرى ، فإن JAX لا تجمع CUDA أو
CuDNN كجزء من حزمة النقطة.
توفر JAX عجلات مسبقة الصنع متوافقة مع CUDA لنظام Linux فقط ،
مع CUDA 11.1 أو أحدث ، و CuDNN 8.0.5 أو أحدث. مجموعات أخرى من
نظام التشغيل ، CUDA ، و CuDNN ممكنة ، ولكنها تتطلب البناء من المصدر:
- مطلوب كودا 11.1 أو أحدث
قد تتمكن من استخدام إصدارات CUDA الأقدم إذا كنت تقوم بالبناء من المصدر ،
ولكن هناك أخطاء معروفة في CUDA في جميع إصدارات CUDA الأقدم من 11.1 ،
لذلك نحن لا نشحن الثنائيات سابقة الإنشاء لإصدارات CUDA الأقدم.
- إصدارات cuDNN المدعومة للعجلات مسبقة الصنع هي:
*cuDNN 8.2 أو أحدث. نوصي باستخدام عجلة cuDNN 8.2 إذا كان تثبيت cuDNN
الخاص بك جديدًا بما يكفي لأنه يدعم وظائف إضافية.
* cuDNN 8.0.5 أو أحدث.
- يجب عليك استخدام إصدار برنامج تشغيل NVIDIA
يكون جديدًا على الأقل مثل إصدار برنامج التشغيل المقابل لمجموعة أدوات CUDA .
على سبيل المثال ، إذا كان لديك تحديث 4 لـ CUDA 11.4 مثبتًا ، فيجب عليك استخدام
برنامج تشغيل NVIDIA 470.82.01 أو أحدث إذا كان نظام التشغيل Linux.
هذا مطلب صارم موجود لأن JAX تعتمد على كود تجميع JIT ؛ قد يؤدي كبار السن من السائقين إلى الفشل.
إذا كنت بحاجة إلى استخدام مجموعة أدوات CUDA أحدث مع برنامج تشغيل أقدم ،
على سبيل المثال في مجموعة حيث لا يمكنك تحديث برنامج تشغيل NVIDIA بسهولة ،
فقد تتمكن من استخدام حزم التوافق مع التوجيهات CUDA التي توفرها NVIDIA لهذا الغرض:
Shell
pip install --upgrade pip
# Installs the wheel compatible with CUDA 11 and cuDNN 8.2 or newer.
# Note: wheels are only available on Linux.
pip install --upgrade "jax[cuda]" https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
يجب أن يتوافق إصدار jaxlib مع إصدار تثبيت CUDA الحالي الذي تريد استخدامه.
يمكنك تحديد إصدار CUDA و CuDNN معين لـ jaxlib بشكل صريح:
Shell
pip install --upgrade pip
# Installs the wheel compatible with Cuda >= 11.4 and cudnn >= 8.2
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Installs the wheel compatible with Cuda >= 11.1 and cudnn >= 8.0.5
pip install "jax[cuda11_cudnn805]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
يمكنك العثور على إصدار CUDA الخاص بك باستخدام الأمر:
Shell
nvcc --version
تتوقع بعض وظائف GPU أن يكون تثبيت
CUDA على / usr / local / cuda-XX ، حيث يجب استبدال XX
برقم إصدار CUDA (مثل cuda-11.1). إذا تم تثبيت CUDA في مكان
آخر على نظامك ، فيمكنك إما إنشاء ارتباط رمزي :
Shell
sudo ln -s /path/to/cuda /usr/local/cuda-X.X
مقارنة JAX بـ NumPy
نظرًا لأن JAX عبارة عن NumPy معزز ، فإن تركيبها متشابه جدًا ،
مما يمنح المستخدمين القدرة على استخدام الاثنين بالتبادل في المشاريع
التي لا تعمل فيها NumPy أو JAX. هذا غالبًا مع المشاريع الصغيرة حيث
يكون مقدار التسارع ضئيلًا في الوقت الذي يتم توفيره. ومع ذلك ،
مع زيادة حجم النماذج ، يجب أن تفكر في JAX.
*ضرب مصفوفتين باستخدام JAX مقابل NumPy
لتوضيح فرق السرعة بين هاتين المكتبتين بوضوح ، سنستخدم
كلاهما لضرب مصفوفتين في بعضهما البعض ثم التحقق من اختلافات الأداء
بين وحدة المعالجة المركزية فقط ووحدة معالجة الرسومات.
سوف نتحقق أيضًا من تعزيز الأداء الناتج عن مترجم JIT.
لمتابعة هذا البرنامج التعليمي ، قم بتثبيت واستيراد مكتبات JAX و NumPy (من الخطوة السابقة).
يمكنك اختبار الكود الخاص بك على مواقع مثل Kaggle أو Google Colab .
كما هو الحال مع أي مكتبة ، يجب عليك استيراد JAX عن طريق
كتابة الأسطر التالية في بداية التعليمات البرمجية الخاصة بك :
Python
import jax.numpy as jnp
from jax import random
يمكنك أيضًا استيراد مكتبة NumPy بطريقة مماثلة:
Python
import numpy as np
بعد ذلك ، سنقارن أداء كل من JAX و Numpy باستخدام CPU و GPU
بضرب مصفوفتين معًا في Python. بالنسبة لهذه المعايير ، فإن الأقل هو الأفضل.
*NumPy على وحدة المعالجة المركزية
للبدء ، سننشئ مصفوفة من 5000 إلى 5000 باستخدام NumPy ونختبر أداءها من حيث السرعة.
Python
import numpy as np
size = 5000
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
استغرقت حلقة واحدة من الكود الذي يعمل على NumPy حوالي 750 مللي ثانية لكل حلقة للتشغيل.
*JAX على وحدة المعالجة المركزية
لنقم الآن بتشغيل نفس الكود ، لكن هذه المرة باستخدام مكتبة JAX.
Python
import jax.numpy as jnp
size = 5000
x = jnp.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
كما ترى ، تظهر مقارنة أداء JAX و NumPy CPU فقط أن
NumPy هو الخيار الأسرع. على الرغم من أن JAX قد لا يوفر أفضل
أداء مع وحدات المعالجة المركزية العادية ، إلا أنه يوفر أداءً أفضل بكثير مع وحدات معالجة الرسومات.
785 ms per loop
*JAX مع GPU
الآن ، دعنا نحاول إنشاء مصفوفة 5000 × 5000 نفسها ، هذه المرة باستخدام
JAX مع وحدة معالجة الرسومات بدلاً من وحدة المعالجة المركزية العادية:
Python
import jax
import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
size = 5000
x = random.normal(key, (size, size)).astype(jnp.float32)
%time x_jax = jax.device_put(x)
%time jnp.dot(x_jax, x_jax.T).block_until_ready()
%timeit jnp.dot(x_jax, x_jax.T).block_until_ready()
80.6 مللي ثانية لكل حلقة
كما هو موضح بوضوح عند تشغيل JAX على وحدة معالجة الرسومات بدلاً
من وحدة المعالجة المركزية ، نحقق وقتًا أفضل بكثير يبلغ حوالي 80
مللي ثانية لكل حلقة (حوالي 15 ضعف الأداء). سيكون من الأسهل
رؤية ذلك عند استخدام مصفوفات أكبر أو مقاييس زمنية.
* تجميع في الوقت المناسب (JIT)
باستخدام الأمر jit ، سيتم تجميع الكود الخاص بنا باستخدام
مترجم XLA محدد ، مما يسمح بتنفيذ وظائفنا بكفاءة.
تستخدم مكتبات مثل JAX و Tensorflow ، اختصارًا للجبر الخطي المتسارع ،
لترجمة وتشغيل التعليمات البرمجية على وحدة معالجة الرسومات بكفاءة أكبر.
لتلخيص ذلك ، فإن XLA عبارة عن مترجم جبر خطي محدد قادر على تجميع التعليمات البرمجية بسرعة أعلى بكثير.
سنختبر الكود الخاص بنا باستخدام وظيفة selu_np ، والتي تعني
الوحدة الخطية الأسية المتدرجة ، ونتحقق من أداء الوقت المختلف بين NumPy على
وحدة المعالجة المركزية العادية ، وتشغيل JAX على وحدة معالجة الرسومات مع JIT:
Python
def selu_np(x, alpha=1.67, lmbda=1.05):
return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)
def selu_jax(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
* NumPy على وحدة المعالجة المركزية
للبدء ، سننشئ متجهًا بحجم 1،000،000 باستخدام مكتبة NumPy:
Python
import numpy as np
x = np.random.normal(size=(1000000,)).astype(np.float32)
%timeit selu_np(x)
8.3 مللي ثانية لكل حلقة
*JAX على GPU مع JIT
الآن سنختبر الكود الخاص بنا أثناء استخدام JAX و JIT على وحدة معالجة الرسومات:
Python
import jax
import jax.numpy as jnp
from jax import random
from jax import grad, jit
key = random.PRNGKey(0)
def selu_np(x, alpha=1.67, lmbda=1.05):
return lmbda * np.where(x > 0, x, alpha * np.exp(x) - alpha)
def selu_jax(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = random.normal(key, (1000000,))
selu_jax_jit = jit(selu_jax)
%time x_jax = jax.device_put(x)
%time selu_jax_jit(x_jax).block_until_ready()
%timeit selu_jax_jit(x_jax).block_until_ready()
153 ميكرو ثانية لكل حلقة (0.153 ميلي ثانية لكل حلقة)
أخيرًا ، عند استخدام مترجم JIT مع وحدة معالجة الرسومات ، نحصل
على أداء أفضل بكثير من استخدام وحدة معالجة الرسومات العادية.
كما يمكنك بوضوح ، ترى الفرق واضحًا جدًا ، زيادة السرعة بنسبة 5000٪
تقريبًا أو أسرع 50 مرة من NumPy إلى JAX مع JIT!
فكر في JAX على أنه تعديل لـ NumPy لتمكين التعلم الآلي المتسارع
باستخدام وحدات معالجة الرسومات. نظرًا لأنه لا يمكن ترجمة NumPy إلا
إلى وحدة المعالجة المركزية ، فإن JAX يكون أسرع من NumPy إذا
اخترت تنفيذ التعليمات البرمجية على وحدات معالجة الرسومات. كقاعدة عامة ،
يجب عليك استخدام JAX عندما تخطط لاستخدام NumPy مع وحدات
معالجة الرسومات أو استخدام تجميع كود JIT.
حدود JAX: وظائف خالصة
تم تصميم تحويلات وتعقيدات JAX لوظائف Python النقية وظيفيًا.
لا يمكن للوظائف النقية تغيير حالة البرنامج عن طريق الوصول إلى المتغيرات
الخارجية ، ولا يمكن أن يكون لها آثار جانبية على وظائف مثل تدفقات الإدخال / الإخراج مثل print ().
تؤدي الجولات المتتالية إلى عدم أداء هذه الآثار الجانبية على النحو المنشود.
إذا لم تكن حريصًا ، فقد تؤدي الآثار الجانبية التي لم يتم تتبعها إلى إقصاء دقة حساباتك المقصودة.
باستخدام جوجل JAX
في هذه المقالة ، شرحنا إمكانيات JAX والمزايا التي توفرها لـ NumPy.
تناولنا كيفية تثبيت مكتبة JAX ومزاياها للتعلم الآلي.
ثم انتقلنا إلى استيراد JAX و NumPy. علاوة على ذلك ، قمنا بمقارنة JAX
مع NumPy (وهي أشهر مكتبة منافسة هناك) وكشفنا عن فروق الوقت والأداء
بين هذين باستخدام وحدات المعالجة المركزية ووحدات معالجة الرسومات
العادية جنبًا إلى جنب مع بعض اختبارات JIT أيضًا وشهدنا تحسينات جذرية في السرعة.
إذا كنت ممارسًا متقدمًا في تعلم الآلة / التعلم العميق ، فإن إضافة مكتبة مثل
JAX إلى ترسانتك باستخدام مسرعات (GPU / TPU)
ومترجم JIT الفعال سيجعل الحياة أسهل كثيرًا بالتأكيد.