"""Build a clean, normalized dataset from raw seed CSV files.
This script is intentionally small and opinionated for students.
It reads raw seed CSV files and outputs normalized, validated CSVs plus a quality report.
Input files (default: ml/):
- raw_stations.csv
- raw_places.csv
- raw_edges.csv
Output files (default: ml/out/):
- stations.csv
- places.csv
- edges.csv
- quality_report.json
"""
from __future__ import annotations
import argparse
import csv
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
def _read_csv (path: Path) -> List[Dict[ str , str ]]:
with path.open( "r" , encoding = "utf-8" , newline = "" ) as f:
reader = csv.DictReader(f)
return [ dict (row) for row in reader]
def _write_csv (path: Path, rows: Iterable[Dict[ str , Any]], fieldnames: List[ str ]) -> None :
path.parent.mkdir( parents = True , exist_ok = True )
with path.open( "w" , encoding = "utf-8" , newline = "" ) as f:
writer = csv.DictWriter(f, fieldnames = fieldnames)
writer.writeheader()
for row in rows:
writer.writerow({k: row.get(k, "" ) for k in fieldnames})
def _to_float (v: Any) -> Optional[ float ]:
if v is None :
return None
s = str (v).strip()
if s == "" :
return None
try :
return float (s)
except ValueError :
return None
def _to_int (v: Any) -> Optional[ int ]:
if v is None :
return None
s = str (v).strip()
if s == "" :
return None
try :
return int ( float (s))
except ValueError :
return None
def _norm_id (v: Any) -> str :
return str (v or "" ).strip()
def _norm_text (v: Any) -> str :
return " " .join( str (v or "" ).strip().split())
def _norm_category (v: Any) -> str :
s = _norm_text(v).lower()
s = s.replace( " " , "_" )
return s
@dataclass
class QualityReport:
    """Summary of what the build kept, dropped, and flagged.

    All counters are grouped per table (stations/places/edges) or per
    validation rule; ``warnings`` carries free-form human-readable notes.
    """

    input_rows: Dict[str, int]
    output_rows: Dict[str, int]
    dropped_rows: Dict[str, int]
    duplicate_keys: Dict[str, int]
    invalid_values: Dict[str, int]
    warnings: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Return the report as a plain dict, ready for json.dumps."""
        return {
            name: getattr(self, name)
            for name in (
                "input_rows",
                "output_rows",
                "dropped_rows",
                "duplicate_keys",
                "invalid_values",
                "warnings",
            )
        }
def _clean_stations(
    rows: List[Dict[str, str]],
) -> Tuple[List[Dict[str, Any]], set, int, int, int]:
    """Normalize station rows.

    Returns (stations, seen_ids, dropped, dupes, latlon_missing).
    Rows without a station_id are dropped; repeated ids keep the first
    occurrence; missing/invalid lat/lon is kept (as "") but counted.
    """
    stations: List[Dict[str, Any]] = []
    seen: set = set()
    dropped = 0
    dupes = 0
    latlon_missing = 0
    for r in rows:
        station_id = _norm_id(r.get("station_id"))
        if not station_id:
            dropped += 1
            continue
        if station_id in seen:
            dupes += 1
            continue
        seen.add(station_id)
        lat = _to_float(r.get("lat"))
        lon = _to_float(r.get("lon"))
        if lat is None or lon is None:
            latlon_missing += 1
        stations.append(
            {
                "station_id": station_id,
                "station_name": _norm_text(r.get("station_name")),
                "line_name": _norm_text(r.get("line_name")),
                "lat": "" if lat is None else round(lat, 6),
                "lon": "" if lon is None else round(lon, 6),
            }
        )
    return stations, seen, dropped, dupes, latlon_missing


def _clean_places(
    rows: List[Dict[str, str]],
    station_index: set,
) -> Tuple[List[Dict[str, Any]], int, int, int, int, int]:
    """Normalize place rows against the known station ids.

    Returns (places, dropped, dupes, key_dupes, rating_count_invalid,
    avg_rating_invalid). Validation order matters and mirrors the report:
    missing id/name -> dropped; duplicate place_id -> dupes; unknown
    station -> dropped; duplicate (name, station, category) -> key_dupes.
    """
    places: List[Dict[str, Any]] = []
    seen_ids: set = set()
    # Seed de-dup key: (name, station, category)
    key_seen: set = set()
    dropped = 0
    dupes = 0
    key_dupes = 0
    rating_count_invalid = 0
    avg_rating_invalid = 0
    for r in rows:
        place_id = _norm_id(r.get("place_id"))
        place_name = _norm_text(r.get("place_name"))
        station_id = _norm_id(r.get("station_id"))
        category = _norm_category(r.get("category"))
        if not place_id or not place_name:
            dropped += 1
            continue
        if place_id in seen_ids:
            dupes += 1
            continue
        seen_ids.add(place_id)
        if not station_id or station_id not in station_index:
            dropped += 1
            continue
        key = (place_name.lower(), station_id, category)
        if key in key_seen:
            key_dupes += 1
            continue
        key_seen.add(key)
        rating_count = _to_int(r.get("rating_count"))
        if rating_count is None or rating_count < 0:
            # Invalid counts are coerced to a non-negative int, not dropped.
            rating_count_invalid += 1
            rating_count = max(0, rating_count or 0)
        avg_rating = _to_float(r.get("avg_rating"))
        if avg_rating is None or not (0.0 <= avg_rating <= 5.0):
            # Out-of-range ratings are blanked rather than clamped.
            avg_rating_invalid += 1
            avg_rating = None
        price_level = _to_int(r.get("price_level"))
        if price_level is None or not (1 <= price_level <= 4):
            price_level = None
        places.append(
            {
                "place_id": place_id,
                "place_name": place_name,
                "category": category,
                "station_id": station_id,
                "rating_count": rating_count,
                "avg_rating": "" if avg_rating is None else round(avg_rating, 2),
                "price_level": "" if price_level is None else price_level,
            }
        )
    return places, dropped, dupes, key_dupes, rating_count_invalid, avg_rating_invalid


def _clean_edges(
    rows: List[Dict[str, str]],
    station_index: set,
) -> Tuple[List[Dict[str, Any]], int, int, int]:
    """Normalize edge rows against the known station ids.

    Returns (edges, dropped, dupes, eta_invalid). Edges with a missing
    endpoint/mode, an unknown station, or an eta outside (0, 300] are
    dropped; eta violations are additionally counted in eta_invalid.
    De-dup on (from, to, mode) happens after validation, matching the
    original counting order.
    """
    edges: List[Dict[str, Any]] = []
    seen: set = set()
    dropped = 0
    dupes = 0
    eta_invalid = 0
    for r in rows:
        a = _norm_id(r.get("from_station_id"))
        b = _norm_id(r.get("to_station_id"))
        mode = _norm_category(r.get("mode"))
        eta = _to_int(r.get("eta_min"))
        if not a or not b or not mode:
            dropped += 1
            continue
        if a not in station_index or b not in station_index:
            dropped += 1
            continue
        if eta is None or eta <= 0 or eta > 300:
            eta_invalid += 1
            dropped += 1
            continue
        key = (a, b, mode)
        if key in seen:
            dupes += 1
            continue
        seen.add(key)
        edges.append({"from_station_id": a, "to_station_id": b, "mode": mode, "eta_min": eta})
    return edges, dropped, dupes, eta_invalid


def build_dataset(raw_dir: Path, out_dir: Path) -> QualityReport:
    """Read raw_*.csv from *raw_dir*, write normalized CSVs + report to *out_dir*.

    Stations are cleaned first so places and edges can be validated against
    the surviving station ids. Returns the QualityReport that is also
    serialized to out_dir/quality_report.json.
    """
    stations_raw = _read_csv(raw_dir / "raw_stations.csv")
    places_raw = _read_csv(raw_dir / "raw_places.csv")
    edges_raw = _read_csv(raw_dir / "raw_edges.csv")
    input_rows = {
        "stations": len(stations_raw),
        "places": len(places_raw),
        "edges": len(edges_raw),
    }
    warnings: List[str] = []

    # station_index is the set of surviving station ids; places/edges that
    # reference anything else are dropped.
    stations, station_index, station_dropped, station_dupes, latlon_missing = (
        _clean_stations(stations_raw)
    )
    (
        places,
        place_dropped,
        place_dupes,
        place_key_dupes,
        rating_count_invalid,
        avg_rating_invalid,
    ) = _clean_places(places_raw, station_index)
    edges, edge_dropped, edge_dupes, eta_invalid = _clean_edges(edges_raw, station_index)

    invalid_values = {
        "stations_latlon_missing": latlon_missing,
        "places_avg_rating_invalid": avg_rating_invalid,
        "places_rating_count_invalid": rating_count_invalid,
        "edges_eta_invalid": eta_invalid,
    }

    out_dir.mkdir(parents=True, exist_ok=True)
    _write_csv(out_dir / "stations.csv", stations, ["station_id", "station_name", "line_name", "lat", "lon"])
    _write_csv(
        out_dir / "places.csv",
        places,
        ["place_id", "place_name", "category", "station_id", "rating_count", "avg_rating", "price_level"],
    )
    _write_csv(out_dir / "edges.csv", edges, ["from_station_id", "to_station_id", "mode", "eta_min"])

    report = QualityReport(
        input_rows=input_rows,
        output_rows={"stations": len(stations), "places": len(places), "edges": len(edges)},
        dropped_rows={"stations": station_dropped, "places": place_dropped, "edges": edge_dropped},
        duplicate_keys={
            "stations_station_id": station_dupes,
            "places_place_id": place_dupes,
            "places_name_station_category": place_key_dupes,
            "edges_from_to_mode": edge_dupes,
        },
        invalid_values=invalid_values,
        warnings=warnings,
    )
    (out_dir / "quality_report.json").write_text(
        json.dumps(report.to_dict(), ensure_ascii=False, indent=2) + " \n ",
        encoding="utf-8",
    )
    return report
def main() -> int:
    """CLI entry point: parse arguments, run the build, print the report; returns 0."""
    cli = argparse.ArgumentParser(description="Build a clean dataset from raw seed CSV files.")
    cli.add_argument("--raw-dir", default="ml", help="Directory containing raw_*.csv")
    cli.add_argument("--out-dir", default="ml/out", help="Output directory")
    opts = cli.parse_args()
    summary = build_dataset(Path(opts.raw_dir), Path(opts.out_dir))
    print("Dataset build finished.")
    print(json.dumps(summary.to_dict(), ensure_ascii=False, indent=2))
    return 0
if __name__ == "__main__":
    # SystemExit propagates main()'s return code as the process exit status.
    raise SystemExit(main())