Skip to content

Commit

Permalink
addressing the PR comments for the script: classification_with_grn_an…
Browse files Browse the repository at this point in the history
…d_vsn.py
  • Loading branch information
Humbulani1234 committed Jan 13, 2025
1 parent 6399fe2 commit 2aaadce
Show file tree
Hide file tree
Showing 3 changed files with 1,106 additions and 1,368 deletions.
73 changes: 12 additions & 61 deletions examples/structured_data/classification_with_grn_and_vsn.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,6 @@
valid_data.to_csv(valid_data_file, index=False, header=False)
test_data.to_csv(test_data_file, index=False, header=False)

"""
Clean the directory for the downloaded files except the .tar.gz file and
also remove the empty directories
"""

subprocess.run(
f'find {extracted_path} -type f ! -name "*.tar.gz" -exec rm -f {{}} +',
shell=True,
check=True,
)
subprocess.run(
f"find {extracted_path} -type d -empty -exec rmdir {{}} +", shell=True, check=True
)

"""
## Define dataset metadata
Expand Down Expand Up @@ -337,6 +323,10 @@ def __init__(self, units):
def call(self, inputs):
return self.linear(inputs) * self.sigmoid(inputs)

# Remove build warnings
def build(self):
self.built = True


"""
## Implement the Gated Residual Network
Expand Down Expand Up @@ -372,6 +362,10 @@ def call(self, inputs):
x = self.layer_norm(x)
return x

# Remove build warnings
def build(self):
self.built = True


"""
## Implement the Variable Selection Network
Expand Down Expand Up @@ -446,52 +440,9 @@ def call(self, inputs):
for idx, input in enumerate(concat_inputs):
x.append(self.grns[idx](input))
x = keras.ops.stack(x, axis=1)

# The reason for each individual backend calculation is that I couldn't find
# the equivalent keras operation that is backend-agnostic. In the following case there,s
# a keras.ops.matmul but it was returning errors. I could have used the tensorflow matmul
# for all backends, but due to jax jit tracing it results in an error.
def matmul_dependent_on_backend(tensor_1, tensor_2):
"""
Function for executing matmul for each backend.
"""
# jax backend
if keras.backend.backend() == "jax":
import jax.numpy as jnp

result = jnp.sum(tensor_1 * tensor_2, axis=1)
elif keras.backend.backend() == "torch":
result = torch.sum(tensor_1 * tensor_2, dim=1)
# tensorflow backend
elif keras.backend.backend() == "tensorflow":
result = keras.ops.squeeze(tf.matmul(tensor_1, tensor_2, transpose_a=True), axis=1)
# unsupported backend exception
else:
raise ValueError(
"Unsupported backend: {}".format(keras.backend.backend())
)
return result

# jax backend
if keras.backend.backend() == "jax":
# This repetative imports are intentional to force the idea of backend
# separation
import jax.numpy as jnp

result_jax = matmul_dependent_on_backend(v, x)
return result_jax
# torch backend
if keras.backend.backend() == "torch":
import torch

result_torch = matmul_dependent_on_backend(v, x)
return result_torch
# tensorflow backend
if keras.backend.backend() == "tensorflow":
import tensorflow as tf

result_tf = keras.ops.squeeze(tf.matmul(v, x, transpose_a=True), axis=1)
return result_tf
return keras.ops.squeeze(
keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1
)

# to remove the build warnings
def build(self):
Expand Down Expand Up @@ -520,7 +471,7 @@ def create_model(encoding_size):
learning_rate = 0.001
dropout_rate = 0.15
batch_size = 265
num_epochs = 1 # maybe adjusted to a desired value
num_epochs = 1 # may be adjusted to a desired value
encoding_size = 16

model = create_model(encoding_size)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,34 +282,6 @@
"test_data.to_csv(test_data_file, index=False, header=False)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"Clean the directory for the downloaded files except the .tar.gz file and\n",
"also remove the empty directories"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"subprocess.run(\n",
" f'find {extracted_path} -type f ! -name \"*.tar.gz\" -exec rm -f {{}} +',\n",
" shell=True,\n",
" check=True,\n",
")\n",
"subprocess.run(\n",
" f\"find {extracted_path} -type d -empty -exec rmdir {{}} +\", shell=True, check=True\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down Expand Up @@ -505,6 +477,10 @@
"\n",
" def call(self, inputs):\n",
" return self.linear(inputs) * self.sigmoid(inputs)\n",
"\n",
" # Remove build warnings\n",
" def build(self):\n",
" self.built = True\n",
""
]
},
Expand Down Expand Up @@ -554,6 +530,10 @@
" x = inputs + self.gated_linear_unit(x)\n",
" x = self.layer_norm(x)\n",
" return x\n",
"\n",
" # Remove build warnings\n",
" def build(self):\n",
" self.built = True\n",
""
]
},
Expand Down Expand Up @@ -642,52 +622,9 @@
" for idx, input in enumerate(concat_inputs):\n",
" x.append(self.grns[idx](input))\n",
" x = keras.ops.stack(x, axis=1)\n",
"\n",
" # The reason for each individual backend calculation is that I couldn't find\n",
" # the equivalent keras operation that is backend-agnostic. In the following case there,s\n",
" # a keras.ops.matmul but it was returning errors. I could have used the tensorflow matmul\n",
" # for all backends, but due to jax jit tracing it results in an error.\n",
" def matmul_dependent_on_backend(thsi, v):\n",
" \"\"\"\n",
" Function for executing matmul for each backend.\n",
" \"\"\"\n",
" # jax backend\n",
" if keras.backend.backend() == \"jax\":\n",
" import jax.numpy as jnp\n",
"\n",
" result = jnp.sum(thsi * v, axis=1)\n",
" elif keras.backend.backend() == \"torch\":\n",
" result = torch.sum(thsi * v, dim=1)\n",
" # tensorflow backend\n",
" elif keras.backend.backend() == \"tensorflow\":\n",
" result = keras.ops.squeeze(tf.matmul(thsi, v, transpose_a=True), axis=1)\n",
" # unsupported backend exception\n",
" else:\n",
" raise ValueError(\n",
" \"Unsupported backend: {}\".format(keras.backend.backend())\n",
" )\n",
" return result\n",
"\n",
" # jax backend\n",
" if keras.backend.backend() == \"jax\":\n",
" # This repetative imports are intentional to force the idea of backend\n",
" # separation\n",
" import jax.numpy as jnp\n",
"\n",
" result_jax = matmul_dependent_on_backend(v, x)\n",
" return result_jax\n",
" # torch backend\n",
" if keras.backend.backend() == \"torch\":\n",
" import torch\n",
"\n",
" result_torch = matmul_dependent_on_backend(v, x)\n",
" return result_torch\n",
" # tensorflow backend\n",
" if keras.backend.backend() == \"tensorflow\":\n",
" import tensorflow as tf\n",
"\n",
" result_tf = keras.ops.squeeze(tf.matmul(v, x, transpose_a=True), axis=1)\n",
" return result_tf\n",
" return keras.ops.squeeze(\n",
" keras.ops.matmul(keras.ops.transpose(v, axes=[0, 2, 1]), x), axis=1\n",
" )\n",
"\n",
" # to remove the build warnings\n",
" def build(self):\n",
Expand Down Expand Up @@ -744,7 +681,7 @@
"learning_rate = 0.001\n",
"dropout_rate = 0.15\n",
"batch_size = 265\n",
"num_epochs = 1 # maybe adjusted to a desired value\n",
"num_epochs = 1 # may be adjusted to a desired value\n",
"encoding_size = 16\n",
"\n",
"model = create_model(encoding_size)\n",
Expand Down
Loading

0 comments on commit 2aaadce

Please sign in to comment.