From 1ed5cb6f892d94a6f2261a8c14803324ddda9266 Mon Sep 17 00:00:00 2001 From: Shashank S Date: Thu, 29 Aug 2024 22:54:53 +0530 Subject: [PATCH 1/2] #235 Added Notebook Example for Torchscript --- notebooks/simple_model_scripted.pt | Bin 0 -> 3989 bytes notebooks/torchscript_example.ipynb | 92 ++++++++++++++++++++++++++++ tests/torchscript_example_test.py | 61 ++++++++++++++++++ 3 files changed, 153 insertions(+) create mode 100644 notebooks/simple_model_scripted.pt create mode 100644 notebooks/torchscript_example.ipynb create mode 100644 tests/torchscript_example_test.py diff --git a/notebooks/simple_model_scripted.pt b/notebooks/simple_model_scripted.pt new file mode 100644 index 0000000000000000000000000000000000000000..54f81b11c612a0d37958a20d407430f800eafedf GIT binary patch literal 3989 zcmbtX3tUY39zQj`j51M~l3psM$2^*9hMI#SGD<2W+nuJFQ*$!S^Gtf^AtFgxYu&XT zn~<(*mqmM-W@zkEmt@h!wySWHl||aFvf-Z7SZ2};_p@g{zd1AK{QuwI?|gsH6T`xp z0Dy%B@ZK>MFau<8hFAcxGekT{z?N~Puo!`OWF80MkSP&i#z11?@qu!%^!Y^9@2ql_ z@WLJy0RMEl@O5-(s_A?%L7mIY$z_6n_?4<8`^{Ch*Xxuo_8WuZyPv8)-?&nB_Lw8M zZQVH~BZn=;S?}(*H^0c7aPIMUakuNDjwXkV}tVM8bs-4KU-CjTdY(D z%u-#(rK<9?qQE|MPt3i`%)IhauygZHaHEGDTxnaVn#B-AzLQZ*6AAu*@w=)-&^)Q64afF~AGS5HHk-Fdvkc zk3Fh4S1QZw<*I@JJEMuq21p;;ohr4wee za}h7GNH9XzP$c6LaOjyUmP_!g0A#6TJPYtSRHw(`)LQ1$?kw5@(5m+RY;g?8x zkQCz0gSiM^VmTugiIrF-pcw+Bz#N%m0v}b)H|J}hsS*Mo&oV_Jtw$hHKopLz*)@@G zRKhoomZwN2p<+YAoVrTTvJ*y)!US`wEt+`mXPy@jrpKvV)HCF_?jnh^Bf>y~-u-YC zE?d%?evV?PdOibLT-lE^Iqh}xCP(qP+tdY#QFgw~6MFBp1^b>LMy@T`UNpGp4(;pR zR}>_V@Os4+h7({-7v&!c+d1)$Nzaq=%OU+X@s}>xUL!1~)Go65p{#LB*Q0X-H#&-l z?W?Y)j{Hn_AB3Z~ZD zb_F*P57h;$FZRA>$g96*&g~DG;6(9Jc-k|{9`7p%QD+U}FiHzsT$3A7_qZFRlnQu$ zETtqK8gFUrk<@YfFuBZ9=92YufQ=Q)Yrj3^;C0tzvzvc#O_F6keMR8hGRriQ#R5jI zIxnbUcgM?&{o$RBEws?KPS+44AXGJ7vg2+=_KUL4tGrjwHe7p<+2>lgTRz3YH{`c2 z#5b}#zJ9nB`qf8$M3(kjtAKjb^rgp-zPA|4=e6pj+y?tpdXktP{-ktaas1xeDf9Pq z_RLzarJ>njQ43+g4$;ne!5i*GW@S57CFjT|Mprj6o(@Gkf{zjA?q9X?`YU=nc>hE~ zKuSTqw=Xp#@W7&ul>g+Pp66vfC;9uP$-ms=);_j6(FGLUCfsQ*34hRZY~9YyZFM))TSB8A&LZCVuf`lu`IjMqAZb$5v9UR!=S zirXe#_O~Qz6(LsbT;6Kcn!P!S(8`bR%JMl)h4Y9|&BNqET&#HCr*#E&2Ucvf-0R@D z_xP>zCja>Cq%{6oclEFzv4?c&HA8*wNogI`Bk0H%p{d80g#XLg>Ia*)yjp7%Hi$`1Mps8#`|JOx60oSNJ>O6-mlX z0{`@7CDLKDp2kAw?5Lty4VTlm6r6DwIP^`&OCyZ@{H7-OPp_%=zZ$suOpHmQ%>9PkG0Z>buV8JZvuL+qIQ)wleO5`+!~ktoF)e zbA`R-3zOZ^d%JuaA0F)}?u;G~J-!Za%14X4xyAQao28Rc|FGZqXzVG=9dOX=BcTl8 z2oag4joXdIeqv&w@r;qyl0hD~0&7#4=${vy0ID8snUGY5Ho%B5|FQMx>m)Qh_N+hd zOIwEZeoXXhchRdoB^!Z6Sm*t}^$TDG5kQy?#msvS*cgF{ejEG%eG?ODWUQ4zQkWxv zb2!*Iz=nArnxCljKpG`5Ado_%2T_A)G+Ge$n;INQVK8V6`b;{FF_RG(K*KBq)An5T zMy*^so;{~xH*^j|bNFZ^?a5IW$sUbF14z`-NZR9yvyMFn4I*IytHF9V z;Ar=0FCDCGbUGW`s`tj7(V*Hrh^hm63C)9#1btWHM#F1Y+1Wbq^U)yt7)\n" + ] + } + ], + "source": [ + "loaded_model = torch.jit.load(\"simple_model_scripted.pt\")\n", + "x = torch.randn(1, 10)\n", + "output = loaded_model(x)\n", + "print(output)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bot", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/torchscript_example_test.py b/tests/torchscript_example_test.py new file mode 100644 index 000000000..e5bf5adbd --- /dev/null +++ b/tests/torchscript_example_test.py @@ -0,0 +1,61 @@ +import torch +import unittest +import numpy as np + +class TestTorchScriptModel(unittest.TestCase): + + @classmethod + def setUpClass(cls): + # Load the TorchScript model + cls.model = torch.jit.load('notebooks/simple_model_scripted.pt') + cls.model.eval() # Set the model to evaluation mode + + def test_model_output_shape(self): + """Test if the model outputs the correct shape.""" + input_tensor = torch.randn(1, 5) # Adjust shape based on model input requirements + output_tensor = self.model(input_tensor) + self.assertEqual(output_tensor.shape, (1, 5), "Output shape mismatch") + + def test_model_output_values(self): + """Test if the model output values are within an expected range.""" + input_tensor = torch.randn(1, 5) + output_tensor = self.model(input_tensor) + # Example: Check if all output values are within the range -1 to 1 + self.assertTrue(torch.all(output_tensor >= -1) and torch.all(output_tensor <= 1), + "Output values out of expected range") + + def test_model_with_different_inputs(self): + """Test the model with various types of inputs to ensure robustness.""" + inputs = [ + torch.zeros(1, 5), + torch.ones(1, 5), + torch.randn(1, 5), + torch.full((1, 5), 0.5) + ] + for input_tensor in inputs: + output_tensor = self.model(input_tensor) + self.assertEqual(output_tensor.shape, (1, 5), "Output shape mismatch with different inputs") + + def test_model_gradients(self): + """Test if the model's gradients are computed correctly.""" + input_tensor = torch.randn(1, 5, requires_grad=True) + output_tensor = self.model(input_tensor) + output_tensor.sum().backward() + self.assertIsNotNone(input_tensor.grad, "Gradients were not computed") + + def test_scripted_model_serialization(self): + """Test if the scripted model can be reloaded and produce consistent outputs.""" + input_tensor = torch.randn(1, 5) + output_original = self.model(input_tensor) + + # Save and reload the scripted model + torch.jit.save(self.model, 'test_scripted_model.pt') + reloaded_model = torch.jit.load('test_scripted_model.pt') + reloaded_model.eval() + + output_reloaded = reloaded_model(input_tensor) + self.assertTrue(torch.allclose(output_original, output_reloaded), + "Outputs differ after reloading the scripted model") + +if __name__ == '__main__': + unittest.main() From 0b35a3b043b7fbc1a33db462ed67f4b35d68d99e Mon Sep 17 00:00:00 2001 From: Shashank S Date: Sat, 31 Aug 2024 17:55:16 +0530 Subject: [PATCH 2/2] Name Changed to Test --- .../{torchscript_example_test.py => test_torchscript_example.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{torchscript_example_test.py => test_torchscript_example.py} (100%) diff --git a/tests/torchscript_example_test.py b/tests/test_torchscript_example.py similarity index 100% rename from tests/torchscript_example_test.py rename to tests/test_torchscript_example.py