Skip to content

Commit

Permalink
Merge pull request #58 from remix/dmitriy/unicode-fixes
Browse files Browse the repository at this point in the history
unicode fixes for loading and writing feeds
  • Loading branch information
invisiblefunnel authored Aug 8, 2019
2 parents 0ba80fa + 468282f commit 4c64963
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 33 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ target/

# virtual environment
venv/
.venv/

scratch/
.DS_Store
Expand Down
34 changes: 26 additions & 8 deletions partridge/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,37 @@ def remove_node_attributes(G: nx.DiGraph, attributes: Union[str, Iterable[str]])


def detect_encoding(f: BinaryIO, limit: int = 2500) -> str:
    """
    Return the encoding of the provided binary input stream.

    Most of the time the input is UTF-8, so the sampled lines are first
    decoded as UTF-8. Only if that fails is the stream rewound to offset 0
    and `chardet` used to determine the encoding heuristically.

    :param f: binary stream to sample, iterated line by line. It must be
        seekable when its content is not valid UTF-8. The stream is left
        wherever sampling stopped.
    :param limit: maximum number of lines examined by each pass. At least
        one line is always examined when the stream is not empty.
    :return: "utf-8" when every sampled line decodes as UTF-8 (including
        an empty stream), otherwise the encoding reported by `chardet`.
    """
    unicode_decodable = True

    # Count lines from 1 so that exactly `limit` lines are sampled.
    for line_no, line in enumerate(f, start=1):
        try:
            line.decode("utf-8")
        except UnicodeDecodeError:
            unicode_decodable = False
            break

        if line_no >= limit:
            break

    if unicode_decodable:
        return "utf-8"

    # The first pass consumed part of the stream; start over for chardet.
    f.seek(0)
    u = UniversalDetector()
    for line_no, line in enumerate(f, start=1):
        u.feed(line)
        if u.done or line_no >= limit:
            break

    u.close()
    # NOTE(review): chardet may report `None` here when it cannot settle on
    # an encoding -- confirm how callers are expected to handle that.
    return u.result["encoding"]


def empty_df(columns: Optional[Iterable[str]] = None) -> pd.DataFrame:
Expand Down
5 changes: 1 addition & 4 deletions partridge/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ def write_feed_dangerously(
your own risk.
"""
nodes = DEFAULT_NODES if nodes is None else nodes
try:
tmpdir = tempfile.mkdtemp()
with tempfile.TemporaryDirectory() as tmpdir:

def write_node(node):
df = feed.get(node)
Expand All @@ -54,7 +53,5 @@ def write_node(node):
outpath, _ = os.path.splitext(outpath)

outpath = shutil.make_archive(outpath, "zip", tmpdir)
finally:
shutil.rmtree(tmpdir)

return outpath
6 changes: 3 additions & 3 deletions tests/test_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@


def test_invalid_source():
    """Building a Feed from the "missing" fixture must raise ValueError."""
    expect_invalid_source = pytest.raises(ValueError, match=r"Invalid source")
    with expect_invalid_source:
        Feed(fixture("missing"))


def test_duplicate_files():
    """Building a Feed from the fixtures directory must raise ValueError."""
    expect_duplicates = pytest.raises(ValueError, match=r"More than one")
    with expect_duplicates:
        Feed(fixtures_dir)


Expand All @@ -26,7 +26,7 @@ def test_bad_edge_config():

feed = Feed(fixture("caltrain-2017-07-24"), config=config)

with pytest.raises(ValueError, message="Edge missing `dependencies` attribute"):
with pytest.raises(ValueError, match=r"Edge missing `dependencies` attribute"):
feed.stop_times


Expand Down
6 changes: 3 additions & 3 deletions tests/test_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@ def test_parse_date():


def test_parse_date_with_invalid_month():
    """A date string with month "14" must be rejected with ValueError."""
    expect_leftover = pytest.raises(
        ValueError, match=r"unconverted data remains: 01"
    )
    with expect_leftover:
        parse_date("20991401")


def test_parse_date_with_invalid_day():
    """A date string with day "33" must be rejected with ValueError."""
    expect_leftover = pytest.raises(
        ValueError, match=r"unconverted data remains: 3"
    )
    with expect_leftover:
        parse_date("20990133")


Expand All @@ -39,7 +39,7 @@ def test_parse_time():


def test_parse_time_with_invalid_input():
    """A time string with an "am" suffix must be rejected with ValueError."""
    # `match` is a regular expression, so the parentheses are escaped;
    # unescaped, `()` is an empty group and the literal "()" is never checked.
    with pytest.raises(ValueError, match=r"invalid literal for int\(\)"):
        parse_time("10:15:00am")


Expand Down
6 changes: 3 additions & 3 deletions tests/test_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def test_load_geo_feed_empty():


def test_missing_dir():
    """Loading a feed from the "missing" fixture must raise ValueError."""
    expect_not_found = pytest.raises(
        ValueError, match=r"File or path not found"
    )
    with expect_not_found:
        ptg.load_feed(fixture("missing"))


Expand All @@ -65,13 +65,13 @@ def test_config_must_be_dag():
config.add_edge("trips.txt", "routes.txt")

path = fixture("amazon-2017-08-06")
with pytest.raises(ValueError, message="Config must be a DAG"):
with pytest.raises(ValueError, match=r"Config must be a DAG"):
ptg.load_feed(path, config=config)


def test_no_service():
    """Reading service ids from the "empty" fixture must raise AssertionError."""
    path = fixture("empty")
    expect_no_service = pytest.raises(AssertionError, match=r"No service")
    with expect_no_service:
        ptg.read_service_ids_by_date(path)


Expand Down
21 changes: 20 additions & 1 deletion tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
import io
import networkx as nx
import pytest

import numpy as np
import pandas as pd
from partridge.utilities import setwrap, remove_node_attributes, empty_df
from partridge.utilities import (
detect_encoding,
empty_df,
remove_node_attributes,
setwrap,
)


def test_setwrap():
Expand Down Expand Up @@ -41,3 +48,15 @@ def test_empty_df():
)

assert actual.equals(expected)


@pytest.mark.parametrize(
    "test_string,encoding",
    [
        # Plain ASCII is a subset of UTF-8.
        (b"abcde", "utf-8"),
        # A genuine multi-byte UTF-8 sequence.
        (b"Eyjafjallaj\xc3\xb6kull", "utf-8"),
        # Not valid UTF-8; expected to be detected as an ISO character set.
        (b"\xC4pple", "ISO-8859-1"),
    ],
)
def test_detect_encoding(test_string, encoding):
    """Check the encoding reported for an in-memory byte stream."""
    stream = io.BytesIO(test_string)
    assert detect_encoding(stream) == encoding
13 changes: 2 additions & 11 deletions tests/test_writers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
import shutil
import tempfile

import partridge as ptg
Expand Down Expand Up @@ -29,8 +28,7 @@ def test_extract_agencies(path):
assert len(trip_ids)
assert len(stop_ids)

try:
tmpdir = tempfile.mkdtemp()
with tempfile.TemporaryDirectory() as tmpdir:
outfile = os.path.join(tmpdir, "test.zip")

result = ptg.extract_feed(
Expand Down Expand Up @@ -58,9 +56,6 @@ def test_extract_agencies(path):
new_df = new_fd.get(node)
assert set(original_df.columns) == set(new_df.columns)

finally:
shutil.rmtree(tmpdir)


@pytest.mark.parametrize(
"path", [zip_file("seattle-area-2017-11-16"), fixture("seattle-area-2017-11-16")]
Expand All @@ -83,8 +78,7 @@ def test_extract_routes(path):
assert len(trip_ids)
assert len(stop_ids)

try:
tmpdir = tempfile.mkdtemp()
with tempfile.TemporaryDirectory() as tmpdir:
outfile = os.path.join(tmpdir, "test.zip")

result = ptg.extract_feed(path, outfile, {"trips.txt": {"route_id": route_ids}})
Expand All @@ -109,6 +103,3 @@ def test_extract_routes(path):
original_df = fd.get(node)
new_df = new_fd.get(node)
assert set(original_df.columns) == set(new_df.columns)

finally:
shutil.rmtree(tmpdir)

0 comments on commit 4c64963

Please sign in to comment.