Skip to content

Commit

Permalink
Merge pull request uwdata#4 from uwdata/master
Browse files Browse the repository at this point in the history
minor fix: 1. allennlp can be run with gpu, 2. check the dict entries…
  • Loading branch information
tongshuangwu authored Jul 23, 2019
2 parents 8f2dd50 + 506823f commit 34a9945
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 8 deletions.
5 changes: 4 additions & 1 deletion errudite/builts/attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,10 @@ def create_from_json(cls, raw: Dict[str, str]) -> 'Attribute':
Attribute
The re-created attribute.
"""
return Attribute(raw['name'], raw['description'], raw['cmd'])
return Attribute(
raw['name'] if "name" in raw else None,
raw['description'] if "description" in raw else None,
raw['cmd'] if "cmd" in raw else None)

@staticmethod
def create(
Expand Down
5 changes: 4 additions & 1 deletion errudite/builts/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,10 @@ def create_from_json(cls, raw: Dict[str, str]) -> 'Group':
Group
The re-created group.
"""
return Group(raw['name'], raw['description'], raw['cmd'])
return Group(
raw['name'] if "name" in raw else None,
raw['description'] if "description" in raw else None,
raw['cmd'] if "cmd" in raw else None)

@staticmethod
def create(
Expand Down
8 changes: 4 additions & 4 deletions errudite/predictors/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ def create_from_json(cls, raw: Dict[str, str]) -> 'Predictor':
"""
try:
return Predictor.by_name(raw["model_class"])(
name=raw["name"],
description=raw["description"],
model_path=raw["model_path"],
model_online_path=raw["model_online_path"])
name=raw["name"] if "name" in raw else None,
description=raw["description"] if "description" in raw else None,
model_path=raw["model_path"] if "model_path" in raw else None,
model_online_path=raw["model_online_path"] if "model_online_path" in raw else None)
except:
raise

Expand Down
11 changes: 9 additions & 2 deletions errudite/predictors/predictor_allennlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from allennlp.models.archival import load_archive
from allennlp.predictors.predictor import Predictor as AllenPredictor
from .predictor import Predictor
import torch

# bidaf-model-2017.08.31.tar.gz
# bidaf-model-2017.09.15-charpad.tar.gz
Expand Down Expand Up @@ -37,10 +38,16 @@ def __init__(self, name: str,
"""
model = None
if model_path:
archive = load_archive(model_path)
if torch.cuda.is_available():
archive = load_archive(model_path, cuda_device=0)
else:
archive = load_archive(model_path)
model = AllenPredictor.from_archive(archive, model_type)
elif model_online_path:
model = AllenPredictor.from_path(model_online_path, model_type)
if torch.cuda.is_available():
model = AllenPredictor.from_path(model_online_path, model_type, cuda_device=0)
else:
model = AllenPredictor.from_path(model_online_path, model_type)
self.predictor = model
Predictor.__init__(self, name, description, model, ['accuracy'])

Expand Down

0 comments on commit 34a9945

Please sign in to comment.