Skip to content
Snippets Groups Projects
Verified Commit 588cf8e7 authored by Marco Aceti
Browse files

Parallelize dataset loading

parent b790d25e
No related branches found
No related tags found
No related merge requests found
......@@ -6,8 +6,8 @@ from datetime import datetime
import pandas as pd
from dateparser import parse
from joblib import Parallel, delayed
from pandas.core.groupby.generic import DataFrameGroupBy
from tqdm import tqdm
from src.analysis import groupby, stat, trajectories_map
from src.analysis.filter import date_filter, railway_company_filter
......@@ -71,6 +71,14 @@ def register_args(parser: argparse.ArgumentParser):
)
@delayed
def _load_train_dataset(train_csv: str) -> pd.DataFrame:
    """Load a single train CSV file and log how many data points it holds.

    The function is wrapped with joblib's ``delayed``, so calling
    ``_load_train_dataset(x)`` does not read anything immediately: it
    captures the call so that ``joblib.Parallel`` can execute it in a worker.

    Args:
        train_csv: Path of the train CSV file to read.

    Returns:
        The data frame produced by ``read_train_csv`` for that file.
    """
    # Build the Path once and reuse it for both the read and the log line.
    path = pathlib.Path(train_csv)
    train_df: pd.DataFrame = read_train_csv(path)
    # Lazy %-style arguments: the message is only rendered when DEBUG is on.
    logging.debug("Loaded %s data points @ %s", len(train_df), path)
    return train_df
def main(args: argparse.Namespace):
with warnings.catch_warnings():
warnings.simplefilter("ignore")
......@@ -88,15 +96,12 @@ def main(args: argparse.Namespace):
# Load dataset
df: pd.DataFrame | DataFrameGroupBy = pd.DataFrame()
logging.info("Loading datasets...")
for train_csv in (
tqdm(args.trains_csv)
if logging.root.getEffectiveLevel() > logging.DEBUG
else args.trains_csv
for train_df in Parallel(n_jobs=-1, verbose=5)(
_load_train_dataset(train_csv) for train_csv in args.trains_csv # type: ignore
):
path = pathlib.Path(train_csv)
train_df: pd.DataFrame = read_train_csv(pathlib.Path(train_csv))
df = pd.concat([df, train_df], axis=0)
logging.debug(f"Loaded {len(train_df)} data points @ {path}")
df.reset_index(drop=True, inplace=True)
stations: pd.DataFrame = read_station_csv(args.station_csv)
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment