From 3ed0d8c972ba9a72a9ea45843d3a6789371be446 Mon Sep 17 00:00:00 2001 From: Maksim Novikov Date: Mon, 19 Oct 2020 21:19:29 +0200 Subject: [PATCH] Initial test for tensoflow model --- tests/conftest.py | 29 +++++++++++- tests/data/dummy_tensorflow/Dummy.model.yaml | 43 ++++++++++++++++++ tests/data/dummy_tensorflow/dummy.py | 12 +++++ .../dummy_tensorflow/model/saved_model.pb | Bin 0 -> 36289 bytes .../variables/variables.data-00000-of-00001 | Bin 0 -> 1253 bytes .../model/variables/variables.index | Bin 0 -> 275 bytes tests/test_server/test_reader.py | 6 +++ tiktorch/server/exemplum.py | 11 +++-- tiktorch/server/reader.py | 1 - 9 files changed, 97 insertions(+), 5 deletions(-) create mode 100644 tests/data/dummy_tensorflow/Dummy.model.yaml create mode 100644 tests/data/dummy_tensorflow/dummy.py create mode 100644 tests/data/dummy_tensorflow/model/saved_model.pb create mode 100644 tests/data/dummy_tensorflow/model/variables/variables.data-00000-of-00001 create mode 100644 tests/data/dummy_tensorflow/model/variables/variables.index diff --git a/tests/conftest.py b/tests/conftest.py index f7e1c447..e41d91ac 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,6 +17,7 @@ TEST_DATA = "data" TEST_PYBIO_ZIPFOLDER = "unet2d" TEST_PYBIO_DUMMY = "dummy" +TEST_PYBIO_TENSORFLOW_DUMMY = "dummy_tensorflow" NNModel = namedtuple("NNModel", ["model", "state"]) @@ -92,7 +93,6 @@ def pybio_model_bytes(data_path): return data - @pytest.fixture def pybio_model_zipfile(pybio_model_bytes): with ZipFile(pybio_model_bytes, mode="r") as zf: @@ -114,6 +114,33 @@ def pybio_dummy_model_bytes(data_path): return data +def archive(directory): + result = io.BytesIO() + + with ZipFile(result, mode="w") as zip_model: + def _archive(path_to_archive): + for path in path_to_archive.iterdir(): + if str(path.name).startswith("__"): + continue + + if path.is_dir(): + _archive(path) + + else: + with path.open(mode="rb") as f: + zip_model.writestr(str(path).replace(str(directory), ""), f.read()) + + _archive(directory) + + return result + + +@pytest.fixture +def pybio_dummy_tensorflow_model_bytes(data_path): + pybio_net_dir = Path(data_path) / TEST_PYBIO_TENSORFLOW_DUMMY + return archive(pybio_net_dir) + + @pytest.fixture def cache_path(tmp_path): return Path(getenv("PYBIO_CACHE_PATH", tmp_path)) diff --git a/tests/data/dummy_tensorflow/Dummy.model.yaml b/tests/data/dummy_tensorflow/Dummy.model.yaml new file mode 100644 index 00000000..4d1a5bf2 --- /dev/null +++ b/tests/data/dummy_tensorflow/Dummy.model.yaml @@ -0,0 +1,43 @@ +name: DummyTFModel +description: A dummy tensorflow model for testing +authors: + - ilastik team +cite: + - text: "Ilastik" + doi: https://doi.org +documentation: dummy.md +tags: [tensorflow] +license: MIT + +format_version: 0.1.0 +language: python +framework: tensorflow + +source: dummy.py::TensorflowModelWrapper + +test_input: null # ../test_input.npy +test_output: null # ../test_output.npy + +# TODO double check inputs/outputs +inputs: + - name: input + axes: bcyx + data_type: float32 + data_range: [-inf, inf] + shape: [1, 1, 128, 128] +outputs: + - name: output + axes: bcyx + data_type: float32 + data_range: [0, 1] + shape: + reference_input: input # FIXME(m-novikov) ignoring for now + scale: [1, 1, 1, 1] + offset: [0, 0, 0, 0] + #halo: [0, 0, 32, 32] # Should be moved to outputs + +prediction: + weights: + source: ./model + hash: {md5: TODO} + dependencies: conda:./environment.yaml diff --git a/tests/data/dummy_tensorflow/dummy.py b/tests/data/dummy_tensorflow/dummy.py new file mode 100644 index 00000000..d0d6fce3 --- /dev/null +++ b/tests/data/dummy_tensorflow/dummy.py @@ -0,0 +1,12 @@ +class TensorflowModelWrapper: + def __init__(self): + self._model = None + + def set_model(self, model): + self._model = model + + def forward(self, input_): + return self._model.predict(input_) + + def __call__(self, *args, **kwargs): + return self._model.predict(*args, **kwargs) diff --git a/tests/data/dummy_tensorflow/model/saved_model.pb b/tests/data/dummy_tensorflow/model/saved_model.pb new file mode 100644 index 0000000000000000000000000000000000000000..a252f80465013a3d5300f9470e6fa703b9878d5c GIT binary patch literal 36289 zcmeG_TWlLwb{bL?siXImWXso?dF9hgFDeR?4bI$6%mR*Q#|b z8Ov)gSC!-##O}zYxt8z^%YQ;)f$#2S~5*dPixwfN<(d}XxyD}##_HW~ilas(66DSo` z?HrrgBj*H!Qh(`sItW66^cU*o95xupRmz3p4jE9jTCu!CBcv}$Vc2*oRcjTk;&^uy zMwEL>uC7&T5*vnp$tLgx0fun?Ps`;zLjZLkAbA)Dv1d&#mX#X+%j6LlLVVVgyi}HV z9i6%emP8!#m`|RBF+@(~F8(yH+$rW1l?>}hFedZ}jX)m)mW-)N?T(U)i3>{$8FBg! z`*1C$Mh18@4BxDmLl8EAevDB<{j-$je3hcr>Kk6b|N&5e|<>Xpk)X_s~&I8&o z@Br{2LJ7ba?q*R|JD4JFEQ1aM*?xU`N1f9fCS3oZz??l=sHt(kPEpUq5?~GMnAEVb zn~AZfzddNO1!8cAe$<-o^B1E7TcgnI>M^*#301^o@U@Tqy zI7s>#g>hsE9(gEexX4gO8Hy zFnWD`>zVb9%a^3<>z7}=+*QwXe%1@K_BQV=U6kSXL+dSyifjxza%5*(#qL9UlH{EH$?Dwa_hBD`*s zI_u%4(=eHxAVZ1&v;Gm7&~l2jV-~;nYYKlG3D#SJ6dKwLWT$AO+^SBE8mdklVq#R7 zno5nBT22h2q9CA{JuKxarIM1wTjSYF08^87**vv_@z)RDY{zZAY%wcB+f`^NNH42;N1uu2~aMj z3<)HY3kup&O`$DhJvGE}YRn)c%5t$kE(mcaCF89q8H9lOri5c_piQEkW<;CcnuB9J z(U|nfq1MT0B#9SR;44kgD~3m}L#GQ|_;L5be{j%sj-l)1)(JT7hc4VmXgeUrE{f!5 ze^-Ei2KXHTW*tq9U67Y{D|w|PU05b#FvPL58G$UpqN@C0T`6ltxn#1sU0K_$m-Ibr zX-k!AW8=%@7OQjdjg@ud$Hkz3MQy>GzR49*ex)9LXHQ}tTDWhr~2Vt8qODLqUiHS!7A)lhbL?5{dA@@xbPQ#Q` zN8L|St4gj|C|WL*5yEd^copDR1=!@-tc@DUuoH6VNRiNSgUk?Z$ueqD<-G2ILKbTn zsZvJF5P$()t5lH>kHDb8m@}DW<4v+IVH<1&5yR5G1a?NCuZWMlAu(=4wv*rV6(7-93yU`T+9Z27U z<7eVyEtVZ=zT>l3kRLu5o*&K!@I&;88C(ngFbbc5C-y~GQ51PL=wx*CBVY+aTOyi+ zcR-rKMsHN{T`v>YjP>_gIE4 zEj47=(uJN3>&YWyOH)B}yc?7`U1zcmQ&bi;ODS##(e=iroK#IJ_4tW zmN(n?R{jc}mm1bWI2UConA-D8qQj#J1oSh46>nic*ziDFCjH4Kzut$VYt%U`{bd1; z86)}o)QCI!A{+I{gar@VC_zeLb;}rvaX@vJ#i0M@an@Yl5 zO9pikj(5PCeoBDpa7?KM1nse-!vUjG70j~_4+om{1#=(!2hTi>F6rk4IOM-QJ^ZE> zj>B`BnV>Q0kKw7c369NO#tm3GzB1fT27Mmsgm-?}c6+p^V{FZOJ8@9F3PMPqf*R$or{ zY<1tPg&_aM0om&RF%5r#b{wz&x&yV?24f+0nBAZhiuY#XE9o6pzwXZ6=Os8&E`joCOm}nVC3-6YUA)d8M|aTo*TQ zDY@I#O0le|^mPHIn3txQ+fsIvdr#I&C6r{jTg3M%<>%v!a!&udc4%%rC87EL@N;7Sj2`(gk@@ zhXyrK69)QO0S2xqER>X(f9b3rSsFSi9o2toe;+ryYCLh1$BdL>$ni}q5z$mIK%g4H zH9bq`SzIK9`|O}U9W$CY8}S7anTZ6Bh*=}{IP_xx<@zoQUa4UStr@x*04yIe98D36 z41*hS%zNm|3WA7z+r7~rf$3|kUnT#f8RNRBkK;@HPzf7;qxOkt^7)QK(N-RFy&>)g zbkfE~1F3Y*^)mx-`~9Rn&Bj9~?E~i{qMO6jJAFL@lh=?(T^ARPL^eH9%A$w641BoZLhi|U znPB+>;!TJwumIU`5K-J0K4O{pJfn-@y0i%wEUeQhUp$Lbe!$S0ubYJLolf5H0_TH5 z7OER8fVeXr^tOT(;fNx-5d%rRagV-Z+hht5AWS~k10+xI>Jpi>-!!m^0miQIN>pI7 zeRkTVk)4?BL*n~T&yS+nO^i5t90DM07I4XFmDy96wKESbQZlpii3q{CR?qHgE!;uJ2|Jx4%$j)1N^+`b6RbH4vrcRUF5Bp%rj?&*vpSTK#TgHGnD8~4qFTj`HR zTSdnP`A?krIfKR_UnZStL=2{v9_l!nUetSCvoc2f18h^hsCT|Tiltk&k#ojrwq4HR z*>}X?!9+6s>nQxb&^j;E;~+f&v$D#vl)1!8gX0*kQsjhwO){oGmAOcfJMzxYCDZvi zzOMIsy-_j;{h{thQii?}fkRtH=I95|1A_X$|KusAkWl-T+a(d;MhiAK!uvtcXbB#d zqU}G}cunk&PDkbOMiYlmt6|&GlXe65)EbT4=Ip%3Qo=Z|#^D}w75I29hB$EPj;~!k z*?=W$IN;Q5+C6q7>vhHh1l9{+K!_a6p}#1AXZaJ)6WV<1`Iz2ruzZ~gysK8gY!>gh zw|P^Xv-N#2k#P?h0uv5PpYEZ<#k=oAVT0!~)XzuY1}+3X(BtBThOY7Dg+lczn)Su# z&j=hgMGl(|f6c@V_o(PbaX&(K7`wrWmboEld6dATT%;I4Zu?r6IFM`{ij8eXLxN^W z;b3iO;jHcz@vM0B}Wb z@W+syq}#6fSRj~GYlz;9YbbUb-&WBaY2W_%-_8?m9Z~!Zkf`}-HPf>`dihGde5EQL zZ}^*Uyeja)us(h=2J3Lz!4r+lRX#Ww6P85rViTA2pGLr)j85DSsFiN}LC-$NlVo?+ zdC6D3opCj@#cxBE3t5sP*3QEIp5+p^qX*WL|8X6RXfR= z>$(HJTM2Wk&P4rQ!kk`Ix|cA=>P};EF|n60*W!qPMUuHP;d%*ktiZTk4dz<|a1T!M zN~vL8J9Fs8j6)^GTI|*l_`cHTS;eJ*WbD8sXG^zz-(}8s>%^~@Gw*kMK+rmCFSaF2 z;@Y1Kz$+o@+j`H|yR@nPu9HpW@Mc+i`~6ar;tZ=vabgP#!0~*}EH-SHS?vX`v=_K? zAZN76SmJBZ@nL!y##IcYwW~j8)8uG?tdKGoPaIg*OZOZ{jPZV)HWQJ>jd}x~iei!6+PEs8x2^ z%<$G4Y*CO)DjToy;FSOvb$w37Uwqs>!)I8D{LAbf4+aL;ksn0vA4`%}uVOBl#C}`E zEF*mg1Ma0y-p7Z2b5$|XNM(q$z^}-ucky z;O16G#=($q0SA^@4SxNW6MT>?;k7@qj$|xYE3u#ud}^*2Gv$tP9jxeG{1g%+>zV9R zr3zk1^kTVKsMK~{lU5ZSe+=^sPf0y3uX%SLQ)(5}*^U0wM*5$5O^XhTxfIQ8Kc6@H zTv=XU;~o(HcV?fsx>TJS8cvVz((}TP*k)xFC&>1CH=OAK1ovBp^1UJ>!THSn`=4Yu z)0wnhNZ1GY`Zl?SyV^Cp)v-AnJy64*999tCwZCDpXvy^i1(EqgWDDk=!XSDjc|L=G zmhKy*tQSjp7iVzqgKgXgl^Uko;GJJ=CVo4ys7gC{zrro4tY~*Dwc9v?z4|bT6Q!++ z7xLiaq`0uWkY*j-#LN_>%ywqd79{M2QJbFQ7R}fu*UZ|SP7d1RPUzaQ>Cxj(^T{4} z;-y7uCFu5O=3FSZDfJv)T1F{DZ#XnNVel^7~XhxDQ{C9>9He z!-+EvMshYZBf;B*0vO2~5%(tM^g@f3mlTM-4zwSO<1|>9PtL(1V?tErSKag7a>{uDAb2sRkQURCdvwUtYcobijmfid$eFCdkEj_ zHo3+t*gZwp_j#w&m(t*p^rv zA20NKNBC+KGvkI>p{@1W4s&ynjB(#OdE~9=85nq}UsXy4NW5f*mIAz#(9JY-ZnL+$ zvU(e@5{YAm9J{5R-8G6)0$A^=cKfdWGa=4 z--`BO#PF!`sPT*F+xVEkkX$ZTSY0smBNFy(Tu5Ob@K1vHZp&K~F)Qn{IM1JO&C|SB zp6DY_*fYRM=@2<)G0M#%0C-HVaRJX7`oNw_?M2VR>DMAj=P8l*2xvxte~5~IAMHX? z_iDNuX&GGrB)!^&POzl>`(X8cQxLx)gbTvKteQ7*4Zq=9!?WgIoqj7yynD6jzU=0T zHh1H<1@T|@UxfRinCiC#@o$8%LH!TKjJ_j?|0ZSV812rKcPdsxEFmC zPVYsNF#lR)*1G7Dcm|E4MACuP+a7%_G6O=|+YDg?U98>-$K9@U2E@#3k!jz-$?tWk z4oXo6?8AG}H4yiri(!Sfsp7PSD!iL`|D{_{F6MQRFYZN`;rw1S6@s9Vs=ySo_a0h| zbMFLNU43i|+OKpcC$;0%&^~@&ZReINV-I(*N;$ZK_fmLJu!wiZpm0PD2Ia>?V+D(c zePn+dp&L+O2303egu(@uYCOS*HQL@m;~M#zl)11~c=3ev(i0(!EVD0Vb3M@PA oj&PPpC-Mm^J?#I&(T~!w(n%LPOq7O{PFh$6?z#3s2wXn@4}#&T(*OVf literal 0 HcmV?d00001 diff --git a/tests/data/dummy_tensorflow/model/variables/variables.data-00000-of-00001 b/tests/data/dummy_tensorflow/model/variables/variables.data-00000-of-00001 new file mode 100644 index 0000000000000000000000000000000000000000..b117356723dc30625f7b7fd3d92d0aed1e271a58 GIT binary patch literal 1253 zcma)5%TC)s6tx3M#y5|`4Ukt`^8*l&5E}$m5K$EgHX)?lSwk5mqZ)%e=BaKgv1%8s zy6hiBYL_gUsOpDw*Kg>eQag@|nZe zZMHCM_MaOcw9B1x;>UofciSgl4tIXt^ap#_d&93kmks6f>w9Iw^Gdd&L|oOnS~CPtpj2cWvrkIC8oNKnSMffXytq)o}Kh zQz4llnmH52oE>M*1({nCGarrlC<=3K9CJR1sX!i*Qop(;UmB|Gp}Ma^l^_;vd2*9Z zD?^>(|}+&yYH4r}8m(8g8>8GK~V^UB6VT(30^HKqr} zX&#h<9+bl#L}n34CE!84T#Hl~x>8J95t0V9Yg v>Me5s(O}MjPYX9NFf!*d>;-ed1b0fZ{UJt<^Njo;W%xn(?}l!bQuo^cbzwP6 literal 0 HcmV?d00001 diff --git a/tests/test_server/test_reader.py b/tests/test_server/test_reader.py index 614462f2..8b14c0c4 100644 --- a/tests/test_server/test_reader.py +++ b/tests/test_server/test_reader.py @@ -21,3 +21,9 @@ def test_eval_model_zip(pybio_model_bytes, cache_path): with ZipFile(pybio_model_bytes) as zf: exemplum = eval_model_zip(zf, devices=["cpu"], cache_path=cache_path) assert isinstance(exemplum, Exemplum) + +@pytest.mark.xfail +def test_eval_tensorflow_model_zip(pybio_dummy_tensorflow_model_bytes, cache_path): + with ZipFile(pybio_dummy_tensorflow_model_bytes) as zf: + exemplum = eval_model_zip(zf, devices=["cpu"], cache_path=cache_path) + assert isinstance(exemplum, Exemplum) diff --git a/tiktorch/server/exemplum.py b/tiktorch/server/exemplum.py index fe58b25c..a5acee10 100644 --- a/tiktorch/server/exemplum.py +++ b/tiktorch/server/exemplum.py @@ -46,11 +46,10 @@ def __init__( pybio_model: nodes.Model, batch_size: int = 1, num_iterations_per_update: int = 2, - _devices=Sequence[torch.device], + _devices=Sequence[str], ): self.max_num_iterations = 0 self.iteration_count = 0 - self.devices = _devices spec = pybio_model.spec self.name = spec.name @@ -89,12 +88,18 @@ def __init__( self.halo = list(zip(self.output_axes, _halo)) self.model = get_instance(pybio_model) - self.model.to(self.devices[0]) if spec.framework == "pytorch": + self.devices = [torch.device(d) for d in _devices] + self.model.to(self.devices[0]) assert isinstance(self.model, torch.nn.Module) if spec.prediction.weights is not None: state = torch.load(spec.prediction.weights.source, map_location=self.devices[0]) self.model.load_state_dict(state) + # elif spec.framework == "tensorflow": + # import tensorflow as tf + # self.devices = [] + # tf_model = tf.keras.models.load_model(spec.prediction.weights.source) + # self.model.set_model(tf_model) else: raise NotImplementedError diff --git a/tiktorch/server/reader.py b/tiktorch/server/reader.py index 1f7b52c1..b265682c 100644 --- a/tiktorch/server/reader.py +++ b/tiktorch/server/reader.py @@ -38,7 +38,6 @@ def eval_model_zip(model_zip: ZipFile, devices: Sequence[str], cache_path: Optio pybio_model = spec.utils.load_model(spec_file_str, root_path=temp_path, cache_path=cache_path) - devices = [torch.device(d) for d in devices] if pybio_model.spec.training is None: return Exemplum(pybio_model=pybio_model, _devices=devices) else: