diff --git a/plotdigitizer/common.py b/plotdigitizer/common.py index bfee90e..f0a34c2 100644 --- a/plotdigitizer/common.py +++ b/plotdigitizer/common.py @@ -18,6 +18,7 @@ img_ = np.zeros((1, 1)) + def cache() -> Path: c = Path(tempfile.gettempdir()) / "plotdigitizer" c.mkdir(parents=True, exist_ok=True) diff --git a/plotdigitizer/image.py b/plotdigitizer/image.py index 1adc587..44566b6 100644 --- a/plotdigitizer/image.py +++ b/plotdigitizer/image.py @@ -12,12 +12,26 @@ from plotdigitizer import common from plotdigitizer import geometry from plotdigitizer import grid +from plotdigitizer import plot from plotdigitizer.trajectory import find_trajectory + +def click_points(event, x, y, _flags, params): + """callback for opencv image""" + assert common.img_ is not None, "No data set" + # Function to record the clicks. + YROWS = common.img_.shape[0] + if event == cv.EVENT_LBUTTONDOWN: + logger.info(f"You clicked on {(x, YROWS-y)}") + common.locations_.append(geometry.Point(x, YROWS - y)) + + class Figure: - def __init__(self, path: Path): + def __init__(self, path: Path, coordinates: T.List[str], indices: T.List[str]): assert path.exists(), f"{path} does not exists." logger.info(f"Reading {path}") + self.indices = list_to_points(indices) + self.coordinates = list_to_points(coordinates) self.path = path self.orignal = cv.imread(self.path) self.imgs = [("orig-gray-normalized", normalize(cv.imread(self.path, 0)))] @@ -37,11 +51,30 @@ def trajectories(self): def extract_trajectories(self): logger.info(f"Extracting trajectories from {infile}") + def map_axis(self): + logger.info("Mapping axis...") + logger.debug( + f"data points {self.coordinates} → location on image {self.indices}" + ) + + if len(self.coordinates) != len(self.indices): + logger.warning( + "Either the location of data-points on the image is not specified or their numbers don't" + " match with given datapoints. Asking user to fill the missing information..." + ) + + # next function uses callback. Needs a global variable to collect + # data. + common.locations_ = self.indices + self.indices = ask_user_to_locate_points(self.coordinates, self._last()) + assert len(self.coordinates) == len(self.indices) + def _last(self): return self.imgs[-1][1] def _append(self, operation: str, img): - self.imgs.append((operation, img)) + self.imgs.append((operation, img)) + def process_image(img, cache_key: T.Optional[str] = None): global params_ @@ -157,3 +190,27 @@ def save_img_in_cache( def normalize(img): """normalize image to 0, 255""" return np.interp(img, (img.min(), img.max()), (0, 255)).astype(np.uint8) + + +def list_to_points(points) -> T.List[geometry.Point]: + ps = [geometry.Point.fromCSV(x) for x in points] + return ps + + +def ask_user_to_locate_points(points, img) -> list: + """Ask user to map axis. Callback function save selected points in + common.locations_""" + cv.namedWindow(common.WindowName_) + cv.setMouseCallback(common.WindowName_, click_points) + while len(common.locations_) < len(points): + i = len(common.locations_) + p = points[i] + pLeft = len(points) - len(common.locations_) + plot.show_frame(img, "Please click on %s (%d left)" % (p, pLeft)) + if len(common.locations_) == len(points): + break + key = cv.waitKey(1) & 0xFF + if key == "q": + break + logger.info("You clicked %s" % common.locations_) + return common.locations_ diff --git a/plotdigitizer/plot.py b/plotdigitizer/plot.py new file mode 100644 index 0000000..2d1f1c6 --- /dev/null +++ b/plotdigitizer/plot.py @@ -0,0 +1,50 @@ +# helper function for plotting. + +import typing as T +from pathlib import Path +import matplotlib.pyplot as plt +import cv2 as cv +import numpy as np +from loguru import logger + +from plotdigitizer import common + + +def show_frame(img, msg="MSG: "): + msgImg = np.zeros(shape=(50, img.shape[1])) + cv.putText(msgImg, msg, (1, 40), 0, 0.5, 255) + newImg = np.vstack((img, msgImg.astype(np.uint8))) + cv.imshow(common.WindowName_, newImg) + + +def plot_traj(traj, img, outfile: T.Optional[Path] = None): + global locations_ + import matplotlib.pyplot as plt + + x, y = zip(*traj) + plt.figure() + plt.subplot(211) + + for p in common.locations_: + csize = img.shape[0] // 40 + cv.circle( + img, + (int(p.x), int(img.shape[0] - p.y)), + int(csize), + (128, 128, 128), + -1, + ) + + plt.imshow(img, interpolation="none", cmap="gray") + plt.axis(False) + plt.title("Original") + plt.subplot(212) + plt.title("Reconstructed") + plt.plot(x, y) + plt.tight_layout() + if not str(outfile): + plt.show() + else: + plt.savefig(outfile) + logger.info(f"Saved to {outfile}") + plt.close() diff --git a/plotdigitizer/plotdigitizer.py b/plotdigitizer/plotdigitizer.py index ac26c5a..b7c2e4e 100644 --- a/plotdigitizer/plotdigitizer.py +++ b/plotdigitizer/plotdigitizer.py @@ -4,14 +4,12 @@ import typing as T from pathlib import Path -import cv2 as cv - -import numpy as np import typer from typing_extensions import Annotated from plotdigitizer import grid from plotdigitizer import image +from plotdigitizer import plot from plotdigitizer import geometry from plotdigitizer import common @@ -22,72 +20,6 @@ app = typer.Typer() -def plot_traj(traj, outfile: Path): - global locations_ - import matplotlib.pyplot as plt - - x, y = zip(*traj) - plt.figure() - plt.subplot(211) - - for p in common.locations_: - csize = common.img_.shape[0] // 40 - cv.circle( - common.img_, (int(p.x), int(common.img_.shape[0] - p.y)), int(csize), (128, 128, 128), -1 - ) - - plt.imshow(common.img_, interpolation="none", cmap="gray") - plt.axis(False) - plt.title("Original") - plt.subplot(212) - plt.title("Reconstructed") - plt.plot(x, y) - plt.tight_layout() - if not str(outfile): - plt.show() - else: - plt.savefig(outfile) - logger.info(f"Saved to {outfile}") - plt.close() - - -def click_points(event, x, y, _flags, params): - assert common.img_ is not None, "No data set" - # Function to record the clicks. - YROWS = common.img_.shape[0] - if event == cv.EVENT_LBUTTONDOWN: - logger.info(f"You clicked on {(x, YROWS-y)}") - common.locations_.append(geometry.Point(x, YROWS - y)) - - -def show_frame(img, msg="MSG: "): - msgImg = np.zeros(shape=(50, img.shape[1])) - cv.putText(msgImg, msg, (1, 40), 0, 0.5, 255) - newImg = np.vstack((img, msgImg.astype(np.uint8))) - cv.imshow(common.WindowName_, newImg) - - -def ask_user_to_locate_points(points, img): - cv.namedWindow(common.WindowName_) - cv.setMouseCallback(common.WindowName_, click_points) - while len(common.locations_) < len(points): - i = len(common.locations_) - p = points[i] - pLeft = len(points) - len(common.locations_) - show_frame(img, "Please click on %s (%d left)" % (p, pLeft)) - if len(common.locations_) == len(points): - break - key = cv.waitKey(1) & 0xFF - if key == "q": - break - logger.info("You clicked %s" % common.locations_) - - -def list_to_points(points) -> T.List[geometry.Point]: - ps = [geometry.Point.fromCSV(x) for x in points] - return ps - - @app.command() def digitize_plot( infile: Path, @@ -122,27 +54,19 @@ def digitize_plot( ), ] = None, ): - figure = image.Figure(infile) + figure = image.Figure(infile, data_point, location) # remove grids. figure.remove_grid() - image.save_img_in_cache(common.img_, infile.name) - - common.points_ = list_to_points(data_point) - common.locations_ = list_to_points(location) - logger.debug(f"data points {data_point} → location on image {location}") - if len(common.locations_) != len(common.points_): - logger.warning( - "Either the location of data-points are not specified or their numbers don't" - " match with given datapoints. Asking user..." - ) - ask_user_to_locate_points(common.points_, common.img_) + # map the axis + figure.map_axis() + # compute trajectories traj = figure.trajectories() if plot_file is not None: - plot_traj(traj, plot_file) + plot.plot_traj(traj, figure._last(), plot_file) outfile = output or f"{infile}.traj.csv" with open(outfile, "w") as f: