diff --git a/errudite/builts/attribute.py b/errudite/builts/attribute.py index 4575b67..011d6d1 100644 --- a/errudite/builts/attribute.py +++ b/errudite/builts/attribute.py @@ -492,7 +492,10 @@ def create_from_json(cls, raw: Dict[str, str]) -> 'Attribute': Attribute The re-created attribute. """ - return Attribute(raw['name'], raw['description'], raw['cmd']) + return Attribute( + raw['name'] if "name" in raw else None, + raw['description'] if "description" in raw else None, + raw['cmd'] if "cmd" in raw else None) @staticmethod def create( diff --git a/errudite/builts/group.py b/errudite/builts/group.py index 3d9d300..1bc99a7 100644 --- a/errudite/builts/group.py +++ b/errudite/builts/group.py @@ -459,7 +459,10 @@ def create_from_json(cls, raw: Dict[str, str]) -> 'Group': Group The re-created group. """ - return Group(raw['name'], raw['description'], raw['cmd']) + return Group( + raw['name'] if "name" in raw else None, + raw['description'] if "description" in raw else None, + raw['cmd'] if "cmd" in raw else None) @staticmethod def create( diff --git a/errudite/predictors/predictor.py b/errudite/predictors/predictor.py index 0c3d6ad..74f3e15 100644 --- a/errudite/predictors/predictor.py +++ b/errudite/predictors/predictor.py @@ -118,10 +118,10 @@ def create_from_json(cls, raw: Dict[str, str]) -> 'Predictor': """ try: return Predictor.by_name(raw["model_class"])( - name=raw["name"], - description=raw["description"], - model_path=raw["model_path"], - model_online_path=raw["model_online_path"]) + name=raw["name"] if "name" in raw else None, + description=raw["description"] if "description" in raw else None, + model_path=raw["model_path"] if "model_path" in raw else None, + model_online_path=raw["model_online_path"] if "model_online_path" in raw else None) except: raise diff --git a/errudite/predictors/predictor_allennlp.py b/errudite/predictors/predictor_allennlp.py index d26b3ac..24f5dee 100644 --- a/errudite/predictors/predictor_allennlp.py +++ b/errudite/predictors/predictor_allennlp.py @@ -4,6 +4,7 @@ from allennlp.models.archival import load_archive from allennlp.predictors.predictor import Predictor as AllenPredictor from .predictor import Predictor +import torch # bidaf-model-2017.08.31.tar.gz # bidaf-model-2017.09.15-charpad.tar.gz @@ -37,10 +38,16 @@ def __init__(self, name: str, """ model = None if model_path: - archive = load_archive(model_path) + if torch.cuda.is_available(): + archive = load_archive(model_path, cuda_device=0) + else: + archive = load_archive(model_path) model = AllenPredictor.from_archive(archive, model_type) elif model_online_path: - model = AllenPredictor.from_path(model_online_path, model_type) + if torch.cuda.is_available(): + model = AllenPredictor.from_path(model_online_path, model_type, cuda_device=0) + else: + model = AllenPredictor.from_path(model_online_path, model_type) self.predictor = model Predictor.__init__(self, name, description, model, ['accuracy'])