From 42568f99119d35d2862d8656878e3332429bf25e Mon Sep 17 00:00:00 2001 From: jameshod5 Date: Wed, 1 May 2024 13:58:46 +0100 Subject: [PATCH] ruff fixes --- src/api/create.py | 6 ++---- src/api/crud.py | 8 ++++---- src/api/main.py | 30 +++++------------------------- src/api/models.py | 3 +-- 4 files changed, 12 insertions(+), 35 deletions(-) diff --git a/src/api/create.py b/src/api/create.py index a4b576c..efacaef 100644 --- a/src/api/create.py +++ b/src/api/create.py @@ -1,4 +1,3 @@ -import json import logging import uuid from enum import Enum @@ -9,12 +8,11 @@ import dask.dataframe as dd import numpy as np import pandas as pd -from sqlalchemy import MetaData, create_engine, dialects, select, types +from sqlalchemy import MetaData, create_engine, select from sqlalchemy_utils.functions import create_database, database_exists, drop_database from sqlmodel import SQLModel from tqdm import tqdm -from . import models from .environment import SQLALCHEMY_DATABASE_URL, SQLALCHEMY_DEBUG logging.basicConfig(level=logging.INFO) @@ -113,7 +111,7 @@ def create_shots(self, data_path: Path): shot_metadata = shot_metadata.drop(["scenario_id", "reference_id"], axis=1) shot_metadata["uuid"] = shot_metadata.index.map(get_dataset_uuid) shot_metadata["url"] = ( - f"s3://mast/shots/" + "s3://mast/shots/" + shot_metadata["campaign"] + "/" + shot_metadata.index.astype(str) diff --git a/src/api/crud.py b/src/api/crud.py index e8b2278..14a8cc5 100644 --- a/src/api/crud.py +++ b/src/api/crud.py @@ -14,7 +14,7 @@ from . import models from .database import engine -from .utils import aggregate_map, comparator_map, get_fields_non_optional +from .utils import aggregate_map, comparator_map COMPARATOR_NAMES_DESCRIPTION = ", ".join( ["$" + name + ":" for name in comparator_map.keys()] @@ -294,9 +294,9 @@ def get_table_as_dataframe(query, name: str, ext: str = "parquet"): df = pd.read_sql(query.statement, con=engine.connect()) columns = df.columns - for column in columns: - if df[column].dtype == uuid.UUID: - df[column] = df[column].astype(str) + for column_item in columns: + if df[column_item].dtype == uuid.UUID: + df[column_item] = df[column].astype(str) stream = io.BytesIO() if media_type == "binary" else io.StringIO() DF_EXPORT_FUNCS[ext](df, stream) diff --git a/src/api/main.py b/src/api/main.py index 4d6e66a..00449c0 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -1,38 +1,18 @@ -import io -import json import os import uuid -from typing import Annotated, List, Optional, get_type_hints +from typing import List, Optional -import h5py -import ndjson -import pandas as pd import sqlmodel -from fastapi import Depends, FastAPI, HTTPException, Query, Request, Response -from fastapi.encoders import jsonable_encoder -from fastapi.responses import ( - FileResponse, - HTMLResponse, - JSONResponse, - RedirectResponse, - StreamingResponse, -) +from fastapi import Depends, FastAPI, Query, Request, Response from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates -from fastapi_pagination import Page, add_pagination -from fastapi_pagination.ext.sqlalchemy import paginate -from pydantic import BaseModel, Field, create_model from sqlalchemy.orm import Session from strawberry.asgi import GraphQL -from strawberry.fastapi import GraphQLRouter from strawberry.http import GraphQLHTTPResponse from strawberry.types import ExecutionResult -from . import crud, graphql, models, utils -from .database import SessionLocal, engine, get_db -from .page import MetadataPage -from .types import FileType -from .utils import InputParams +from . import crud, graphql, models +from .database import get_db templates = Jinja2Templates(directory="src/api/templates") @@ -370,7 +350,7 @@ def get_sources( "/json/sources/{name}", description="Get information about a single signal", ) -def get_signal(db: Session = Depends(get_db), name: str = None) -> models.SourceModel: +def get_source(db: Session = Depends(get_db), name: str = None) -> models.SourceModel: source = crud.get_source(db, name) source = db.execute(source).one()[0] return source diff --git a/src/api/models.py b/src/api/models.py index 7597829..44276ba 100644 --- a/src/api/models.py +++ b/src/api/models.py @@ -3,8 +3,7 @@ from typing import Dict, List, Optional from sqlalchemy import ARRAY, Column, Enum, Integer, Text -from sqlalchemy.dialects.postgresql import JSONB, UUID -from sqlalchemy.orm import relationship +from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field, Relationship, SQLModel from .types import (