Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Init flyte integration #18

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions flyte/workflows/quant_workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import os
import sys
import torch
from typing import List, Optional
from flytekit import task, workflow, Resources

# Make the project's `src` package importable when this file is executed
# directly. NOTE(review): two dirname() calls from flyte/workflows/ resolve
# to the flyte/ directory, not the repo root — if `src/` lives at the repo
# root (as the scripts/ version of this snippet suggests), a third dirname
# may be needed. Confirm against the actual repo layout.
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

from src.config.config_loader import load_and_validate_config
from src.data.composer import DataComposer
from src.models.compose.composer import ModelComposer
from src.quantization.quantizer import Quantizer
from src.training.trainer import Trainer

# Trade a little float32 matmul precision for speed (TensorCore-friendly).
torch.set_float32_matmul_precision('high')

# Config paths offered when the caller does not supply one; load_config()
# falls back to the first entry.
PRECONFIGURED_CONFIGS = [
    "config/rniq_config_resnet20.yaml",
]

@task(limits=Resources(cpu="1", mem="1Gi"))
def load_config(selected_config: Optional[str] = None) -> dict:
    """Load and validate the experiment config.

    Falls back to the first entry of PRECONFIGURED_CONFIGS when no path
    is supplied.
    """
    # NOTE(review): downstream code reads config.data.* via attribute
    # access, which a plain dict does not support — confirm the `dict`
    # return annotation matches what load_and_validate_config returns.
    config_path = selected_config or PRECONFIGURED_CONFIGS[0]
    return load_and_validate_config(config_path)

@task
def initialize_composer(config):
    """Build the ModelComposer for the given config."""
    composer = ModelComposer(config=config)
    return composer

@task
def initialize_quantizer(config):
    """Construct the Quantizer and immediately invoke it, returning the result."""
    quantizer_factory = Quantizer(config=config)
    return quantizer_factory()

@task
def initialize_trainer(config):
    """Build the Trainer for the given config."""
    trainer = Trainer(config=config)
    return trainer

@task
def initialize_data_composer(config):
    """Build the DataComposer for the given config."""
    data_composer = DataComposer(config=config)
    return data_composer

@task
def compose_data(data_composer):
    """Materialize the datamodule described by the data composer."""
    datamodule = data_composer.compose()
    return datamodule

@task
def compose_model(composer):
    """Materialize the model described by the model composer."""
    model = composer.compose()
    return model

@task
def quantize_model(quantizer, model):
    """Quantize the model and return the result.

    `in_place=True` suggests the input model object is mutated as well —
    confirm against Quantizer.quantize.
    """
    quantized = quantizer.quantize(model, in_place=True)
    return quantized

@task
def test_model(trainer, model, data):
    """Run the trainer's test loop on `model`, using `data` as the datamodule."""
    trainer.test(model, datamodule=data)

@task
def fit_model(trainer, model, data):
    """Run the trainer's fit loop on `model`, using `data` as the datamodule."""
    trainer.fit(model, datamodule=data)

@workflow
def model_quantization_workflow(selected_config: Optional[str] = None):
    """End-to-end quantization pipeline: compose data and model, quantize,
    then evaluate -> train -> re-evaluate.

    Args:
        selected_config: optional path to a config YAML; when omitted,
            load_config falls back to the first PRECONFIGURED_CONFIGS entry.
    """
    config = load_config(selected_config=selected_config)
    composer = initialize_composer(config=config)
    quantizer = initialize_quantizer(config=config)
    trainer = initialize_trainer(config=config)
    data_composer = initialize_data_composer(config=config)

    data = compose_data(data_composer=data_composer)
    model = compose_model(composer=composer)
    qmodel = quantize_model(quantizer=quantizer, model=model)

    # test_model/fit_model produce no outputs, so Flyte has no data edges
    # between these three calls and would otherwise be free to schedule
    # them in any order (or concurrently). Chain the nodes explicitly to
    # enforce the intended test -> fit -> test sequence.
    pre_test = test_model(trainer=trainer, model=qmodel, data=data)
    training = fit_model(trainer=trainer, model=qmodel, data=data)
    post_test = test_model(trainer=trainer, model=model, data=data)
    pre_test >> training
    training >> post_test
    # NOTE(review): the final evaluation uses `model`, not `qmodel`. With
    # in_place=True quantization these may refer to the same object in a
    # local run, but as Flyte promises they are distinct — confirm which
    # model the post-training test is meant to cover.

if __name__ == "__main__":
    # Redundant with the module-level call above, but harmless.
    torch.set_float32_matmul_precision('high')

    # Calling a @workflow-decorated function directly executes the tasks
    # locally as plain Python.
    model_quantization_workflow(selected_config=PRECONFIGURED_CONFIGS[0])
10 changes: 3 additions & 7 deletions scripts/rniq_q_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))

from src.config.config_loader import load_and_validate_config
from src.data import CIFAR10DALIDataModule
from src.data import CIFAR10DataModule
from src.data.composer import DataComposer
from src.models.compose.composer import ModelComposer
from src.quantization.quantizer import Quantizer
from src.training.trainer import Trainer
Expand All @@ -17,12 +16,9 @@
composer = ModelComposer(config=config)
quantizer = Quantizer(config=config)()
trainer = Trainer(config=config)
data_composer = DataComposer(config=config)

# data = CIFAR10DALIDataModule()
data = CIFAR10DataModule()
data.batch_size = config.data.batch_size
data.num_workers = config.data.num_workers

data = data_composer.compose()
model = composer.compose()
qmodel = quantizer.quantize(model, in_place=True)

Expand Down
8 changes: 4 additions & 4 deletions src/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .vision_cls.mnist import MNISTDataModule
from .vision_cls.cifar_dali import CIFAR10DALIDataModule
from .vision_cls.cifar import CIFAR10DataModule
__all__ = ["MNISTDataModule", "CIFAR10DataModule", "CIFAR10DALIDataModule"]
# Re-export the vision-classification datamodules under short aliases.
# These names are resolved dynamically elsewhere (DataComposer does
# getattr(datamodules, config.data.dataset_name)), so config files must
# use exactly these aliases.
from .vision_cls.mnist import MNISTDataModule as MNIST
from .vision_cls.cifar_dali import CIFAR10DALIDataModule as CIFAR10_DALI
from .vision_cls.cifar import CIFAR10DataModule as CIFAR10
# NOTE(review): the exported names changed from the full class names; any
# config whose dataset_name still says e.g. "CIFAR10DataModule" will now
# fail the getattr lookup — verify the configs were updated to match.
__all__ = ["MNIST", "CIFAR10", "CIFAR10_DALI"]
14 changes: 14 additions & 0 deletions src/data/composer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from lightning import pytorch as pl
from src import data as datamodules

class DataComposer:
    """Builds a Lightning datamodule from the `data` section of a config.

    The datamodule class is looked up by name in ``src.data``, so
    ``config.data.dataset_name`` must match one of the names exported by
    that package's ``__init__``.
    """

    def __init__(self, config=None) -> None:
        # The config may be attached later, but compose() requires it.
        self.config = config

    def compose(self) -> "pl.LightningDataModule":
        """Instantiate and configure the datamodule named in the config.

        Returns:
            The datamodule with batch_size and num_workers applied from
            ``config.data``.

        Raises:
            ValueError: if no config was supplied, or if
                ``config.data.dataset_name`` is not exported by ``src.data``.
        """
        # Fail loudly instead of the opaque AttributeError that
        # `None.data` would otherwise raise.
        if self.config is None:
            raise ValueError("DataComposer.compose() requires a config; none was provided")
        data_config = self.config.data
        try:
            datamodule_cls = getattr(datamodules, data_config.dataset_name)
        except AttributeError as err:
            # Surface a clear message listing the valid dataset names.
            raise ValueError(
                f"Unknown dataset_name {data_config.dataset_name!r}; "
                f"available: {', '.join(getattr(datamodules, '__all__', []))}"
            ) from err
        datamodule = datamodule_cls()
        datamodule.batch_size = data_config.batch_size
        datamodule.num_workers = data_config.num_workers

        return datamodule