Coverage for src/sl_transit_repl/main.py: 84%

369 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-22 22:06 +0200

1#!/usr/bin/env python3 

2 

3import argparse 

4import json 

5import re 

6import sys 

7import time 

8from collections import defaultdict 

9from datetime import datetime 

10from pathlib import Path 

11from typing import Any 

12 

13import requests 

14import unidecode 

15from prompt_toolkit import prompt 

16from prompt_toolkit.completion import WordCompleter 

17from prompt_toolkit.history import FileHistory 

18from rich.console import Console 

19from rich.table import Table 

20 

21 

class SLTransitREPL:
    """Self-contained class for SL Transit departure queries with interactive REPL interface.

    Loads (and caches on disk) the list of SL sites, answers departure queries
    against the SL Transport API, supports site lookup by ID or by name, and
    offers an interactive prompt with history and completion.
    """

    # Class constants
    BASE_URL = "https://transport.integration.sl.se/v1"
    TRANSPORT_MODES = ["BUS", "TRAM", "METRO", "TRAIN", "FERRY", "SHIP", "TAXI"]
    # Allowed query parameters mapped to the regex their value must fully match
    # (the pattern is anchored with ^...$ in _parse_query).
    VALID_PARAMS = {
        "site": r"\d+",
        "transport": f"({'|'.join(TRANSPORT_MODES)})",
        "line": r"\d+",
        "direction": r"[12]",
        "forecast": r"\d+",
        "show_numbers": r"(?:true|false|TRUE|FALSE)",
        "debug": r"(?:true|false|TRUE|FALSE)",
    }
    DEFAULT_FORECAST = 60
    # Timeout in seconds for every HTTP request to the SL API. Without one,
    # `requests` waits indefinitely on a stalled connection and the REPL hangs.
    REQUEST_TIMEOUT = 10

    # Color configuration for transport modes
    TRANSPORT_COLORS = {
        "BUS": "red3",  # #BF616A variant
        "TRAM": "orange3",
        "METRO": "blue3",  # #007FC8 variant
        "TRAIN": "magenta",
        "FERRY": "purple",
        "SHIP": "dark_green",
        "TAXI": "yellow",
    }

    # Line colors configuration from Stockholm transit system (terminal-friendly approximations)
    # the names are for reference only, not used in code for anything
    # full subway map link on this SL page: https://sl.se/reseplanering/kartor/spartrafikkartor
    # Unknown line designations fall back to white via the defaultdict factory.
    LINE_COLORS = defaultdict(
        lambda: {"color": "white", "name": "Unknown Line"},
        {
            "12": {"color": "green4", "name": "Nockebybanan"},
            "21": {"color": "purple", "name": "Lidingöbanan"},
            "25": {"color": "light_sea_green", "name": "Saltsjöbanan"},
            "26": {"color": "light_sea_green", "name": "Saltsjöbanan"},
            "27": {"color": "purple3", "name": "Roslagsbanan"},
            "28": {"color": "purple3", "name": "Roslagsbanan"},
            "29": {"color": "purple3", "name": "Roslagsbanan"},
            "30": {"color": "orange3", "name": "Tvärbanan"},
            "31": {"color": "orange3", "name": "Tvärbanan"},
            "7": {"color": "grey58", "name": "Spårväg City"},
            "10": {"color": "blue", "name": "Blue Line - Hjulsta to Kungsträdgården"},
            "11": {"color": "blue", "name": "Blue Line - Akalla to Kungsträdgården"},
            "13": {"color": "red", "name": "Red Line - Norsborg to Ropsten"},
            "14": {"color": "red", "name": "Red Line - Fruängen to Mörby centrum"},
            "17": {"color": "green", "name": "Green Line - Åkeshov to Skarpnäck"},
            "18": {"color": "green", "name": "Green Line - Alvik to Farsta strand"},
            "19": {
                "color": "green",
                "name": "Green Line - Hässelby strand to Hagsätra",
            },
            "40": {
                "color": TRANSPORT_COLORS["TRAIN"],
                "name": "Pendeltåg - Uppsala C to Södertälje centrum",
            },
            "41": {
                "color": TRANSPORT_COLORS["TRAIN"],
                "name": "Pendeltåg - Märsta to Södertälje centrum",
            },
            "42X": {
                "color": TRANSPORT_COLORS["TRAIN"],
                "name": "Pendeltåg - Märsta to Nynäshamn",
            },
            "43": {
                "color": TRANSPORT_COLORS["TRAIN"],
                "name": "Pendeltåg - Bålsta to Nynäshamn",
            },
            "44": {
                "color": TRANSPORT_COLORS["TRAIN"],
                "name": "Pendeltåg - Kallhäll to Tumba",
            },
            "48": {
                "color": TRANSPORT_COLORS["TRAIN"],
                "name": "Pendeltåg - Södertälje centrum to Gnesta",
            },
        },
    )

    # Time-based color thresholds (in minutes)
    TIME_WARNING_THRESHOLD = 15  # departures due within 0..15 min are shown dark_cyan
    TIME_DELAY_THRESHOLD = 5  # expected vs. scheduled difference >5 min is shown red

    def __init__(self, app_dir: str | Path | None = None):
        """Initialize the SL Transit REPL.

        Args:
            app_dir: Optional path to application directory. If None, uses '~/.sl_transit_repl'.

        The application directory (created if missing, including parents) will
        contain the following files:
            - cache/sites.json: Cached sites data
            - .repl_history: History file for the REPL
        """
        # Set up app directory
        if app_dir:
            self.app_dir = Path(app_dir).expanduser()
        else:
            self.app_dir = Path.home() / ".sl_transit_repl"

        # parents=True so a custom app dir whose parent does not exist yet still works.
        self.app_dir.mkdir(parents=True, exist_ok=True)

        # Set up paths within app directory
        self.cache_dir = self.app_dir / "cache"
        self.sites_json = self.cache_dir / "sites.json"
        self.history_file = str(self.app_dir / ".repl_history")

        # Ensure cache directory exists
        self.cache_dir.mkdir(exist_ok=True)

        # Initialize console and load sites
        self.console = Console()
        self.sites = self._load_sites()

        # Build search indices for site lookup functionality
        self._build_search_indices()

    def _fetch_sites(self) -> dict[str, dict[str, Any]]:
        """Fetch all sites from the API and return them keyed by site ID (as a string).

        Returns:
            The sites dictionary, or an empty dict (after printing an error) when the
            request fails, returns an HTTP error status, or the body is not valid JSON.
        """
        try:
            response = requests.get(
                f"{self.BASE_URL}/sites", timeout=self.REQUEST_TIMEOUT
            )
            response.raise_for_status()
            sites = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decode errors raised by older `requests` releases.
            self.console.print(f"[red]Error fetching sites: {str(e)}[/red]")
            return {}

        # Transform the list into a dictionary keyed by site ID
        return {str(site["id"]): site for site in sites}

    def _load_sites(self) -> dict[str, dict[str, Any]]:
        """Load the sites dictionary from the JSON cache, refreshing it from the API if needed.

        The cache is refreshed when it is missing, unreadable, not in the
        ``{"metadata": ..., "sites": ...}`` format, has no ``fetch_date``, or is
        stale (see ``_is_cache_stale``). If the refresh returns nothing, any
        previously cached sites are kept and the cache file is left untouched.
        This way a transient network error never replaces good data with an empty
        cache that would then look fresh.
        """
        sites: dict[str, dict[str, Any]] = {}
        need_fetch = True

        try:
            if self.sites_json.exists():
                # _save_sites writes UTF-8 (with ensure_ascii=False); read it back the
                # same way instead of relying on the platform's default encoding.
                with self.sites_json.open("r", encoding="utf-8") as f:
                    cache_data = json.load(f)

                # Only the format with metadata is recognized; anything else is refetched.
                if (
                    isinstance(cache_data, dict)
                    and "metadata" in cache_data
                    and "sites" in cache_data
                ):
                    sites = cache_data["sites"]
                    metadata = cache_data["metadata"]
                    if isinstance(metadata, dict) and metadata.get("fetch_date"):
                        # Store metadata for potential future use
                        self._cache_metadata = metadata
                        need_fetch = self._is_cache_stale()
                    # Otherwise: no usable fetch date, treat as stale (need_fetch stays True)
        except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
            self.console.print(
                f"[yellow]Warning: Could not read sites cache: {e}[/yellow]"
            )

        if need_fetch:
            fetched = self._fetch_sites()
            if fetched:
                sites = fetched
                self._save_sites(sites)
            elif sites:
                # Refresh failed: keep serving the stale cache and retry next start-up.
                self.console.print(
                    "[yellow]Warning: Using previously cached sites data.[/yellow]"
                )

        return sites

    def _save_sites(self, sites: dict[str, dict[str, Any]]) -> None:
        """Save sites dictionary to JSON file with metadata including fetch timestamp.

        Write failures are reported as a warning and otherwise ignored.
        """
        try:
            cache_data = {
                "metadata": {
                    "fetch_date": datetime.now().isoformat(),
                    "version": "1.0",
                },
                "sites": sites,
            }
            with self.sites_json.open("w", encoding="utf-8") as f:
                json.dump(cache_data, f, indent=2, ensure_ascii=False)
        except OSError as e:
            self.console.print(
                f"[yellow]Warning: Could not save sites cache: {e}[/yellow]"
            )

    def _is_cache_stale(self, max_age_hours: int = 24) -> bool:
        """Check if the cached data is stale based on fetch date.

        Args:
            max_age_hours: Maximum age of cache in hours before considering it stale

        Returns:
            True if cache is stale, no metadata is available, or the fetch date cannot
            be parsed; False otherwise.
        """
        if not hasattr(self, "_cache_metadata") or not self._cache_metadata:
            return True

        fetch_date_str = self._cache_metadata.get("fetch_date")
        if not fetch_date_str:
            return True

        try:
            fetch_date = datetime.fromisoformat(fetch_date_str)
            age_hours = (datetime.now() - fetch_date).total_seconds() / 3600
            return age_hours > max_age_hours
        except (ValueError, TypeError):
            return True

    def _get_site_info(self, site_id: str) -> dict[str, Any] | None:
        """Return the cached record for ``site_id``, or print an error and return None."""
        if site_id not in self.sites:
            self.console.print(
                f"[red]Error: Site ID {site_id} not found in known sites.[/red]"
            )
            self.console.print(
                f"[yellow]Hint: Delete {self.sites_json} to refresh the sites cache.[/yellow]"
            )
            return None
        return self.sites[site_id]

    def _build_search_indices(self) -> None:
        """Build search indices for site lookup functionality.

        ``idx_by_id`` maps each site's ``id`` value to its record. ``idx_by_name``
        holds one ``(normalized_text, site)`` pair per site name and per alias,
        scanned by ``_find_sites_by_substring``.
        """
        self.idx_by_id: dict[int, dict] = {}
        self.idx_by_name: list[tuple[str, dict]] = []

        for site in self.sites.values():
            # Index by ID
            self.idx_by_id[site["id"]] = site

            # Normalized name (and aliases) for diacritic/case-insensitive matching
            self.idx_by_name.append((self._normalize_text(site["name"]), site))
            for alias in site.get("alias") or []:
                self.idx_by_name.append((self._normalize_text(alias), site))

    def _normalize_text(self, text: str) -> str:
        """Normalize text by removing diacritics and converting to lowercase."""
        return unidecode.unidecode(text.lower())

    def _find_site_by_id(self, site_id: int) -> dict | None:
        """Find a site by its ID, returning None if it is unknown."""
        return self.idx_by_id.get(site_id)

    def _find_sites_by_substring(self, substring: str) -> list[dict]:
        """Find sites whose name or alias contains the given substring.

        Matches are diacritic-insensitive and case-insensitive. Each site appears
        at most once, and results are sorted by site name.
        """
        normalized_query = self._normalize_text(substring)
        results = []
        seen = set()  # A site can match via several names/aliases; report it once

        for norm_name, site in self.idx_by_name:
            if normalized_query in norm_name and site["id"] not in seen:
                results.append(site)
                seen.add(site["id"])

        return sorted(results, key=lambda x: x["name"])

    def _create_site_table(self, sites: list[dict] | dict, title: str) -> Table:
        """Create a rich table for displaying site results.

        Args:
            sites: A single site record or a list of them; falsy entries are skipped.
            title: Table title.
        """
        if not isinstance(sites, list):
            sites = [sites]

        table = Table(title=title, show_header=True)
        table.add_column("ID", justify="right", style="cyan")
        table.add_column("Name", style="green")
        table.add_column("Aliases", style="yellow")
        table.add_column("Abbreviation", style="blue")
        table.add_column("Coordinates", style="magenta")

        for site in sites:
            if not site:  # Skip None results
                continue
            aliases = ", ".join(site.get("alias") or [])
            lat, lon = site.get("lat"), site.get("lon")
            # Missing coordinates render as an empty cell instead of raising KeyError.
            coords = f"{lat:.4f}, {lon:.4f}" if lat is not None and lon is not None else ""
            table.add_row(
                str(site["id"]),
                site["name"],
                aliases,
                site.get("abbreviation", ""),
                coords,
            )

        return table

    def _create_departure_table(
        self, departures: list, show_direction_numbers: bool = False
    ) -> Table:
        """Create a rich table for departures.

        Args:
            departures: Departure records from the API's ``departures`` list. Each
                record must have ``line`` and ``scheduled`` (ISO timestamp). It may
                also have ``expected``, ``direction``, ``direction_code``, ``state``
                and ``stop_point``.
            show_direction_numbers: Prefix the direction with ``(direction_code)``.

        Returns:
            A table with one colored row per departure.
        """
        table = Table(show_header=True, header_style="bold magenta")
        table.add_column("Line")
        table.add_column("Transport")
        table.add_column("Direction")
        table.add_column("Scheduled")
        table.add_column("Expected")
        table.add_column("Status")
        table.add_column("Platform")

        for dep in departures:
            line_data = dep["line"]
            # Resolve the id fallback only when needed. As an eager .get() default,
            # str(line_data["id"]) raised KeyError even when a designation existed.
            line = line_data.get("designation")
            if line is None:
                line = str(line_data.get("id", "N/A"))
            transport = line_data.get("transport_mode", "N/A")
            direction_code = dep.get("direction_code", "N/A")
            direction = dep.get("direction", "N/A")

            # Format direction with number if requested
            if show_direction_numbers and direction_code != "N/A":
                direction = f"({direction_code}) {direction}"

            scheduled_time = datetime.fromisoformat(dep["scheduled"])
            expected_time = (
                datetime.fromisoformat(dep["expected"])
                if "expected" in dep
                else scheduled_time
            )
            scheduled = scheduled_time.strftime("%H:%M")
            expected = expected_time.strftime("%H:%M")
            status = dep.get("state", "N/A")
            platform = dep.get("stop_point", {}).get("designation", "N/A")

            # Apply colors
            # Line color (and direction gets same color)
            line_color = self.LINE_COLORS[line]["color"]
            colored_line = f"[{line_color}]{line}[/{line_color}]"
            colored_direction = f"[{line_color}]{direction}[/{line_color}]"

            # Transport color
            # transport_color = self.TRANSPORT_COLORS.get(transport, "white")
            # colored_transport = f"[{transport_color}]{transport}[/{transport_color}]"

            # Status color
            if status == "CANCELLED":
                colored_status = f"[red3]{status}[/red3]"
            elif status == "ATSTOP":
                colored_status = f"[bold]{status}[/bold]"
            elif status == "EXPECTED":
                colored_status = f"[dim]{status}[/dim]"
            else:
                colored_status = status

            # Time colors - check if within warning threshold or delayed.
            # Compare against an aware "now" when the API timestamps carry a tz.
            if scheduled_time.tzinfo is not None:
                now = datetime.now(scheduled_time.tzinfo)
            else:
                now = datetime.now()

            scheduled_diff = (scheduled_time - now).total_seconds() / 60
            expected_diff = (expected_time - now).total_seconds() / 60
            time_delay = abs((expected_time - scheduled_time).total_seconds() / 60)

            # Highlight the scheduled time if it is within TIME_WARNING_THRESHOLD minutes
            if 0 <= scheduled_diff <= self.TIME_WARNING_THRESHOLD:
                colored_scheduled = f"[dark_cyan]{scheduled}[/dark_cyan]"
            else:
                colored_scheduled = f"[dim]{scheduled}[/dim]"

            # Expected time: proximity highlight wins; otherwise red when off-schedule
            if 0 <= expected_diff <= self.TIME_WARNING_THRESHOLD:
                colored_expected = f"[dark_cyan]{expected}[/dark_cyan]"
            elif time_delay > self.TIME_DELAY_THRESHOLD:
                colored_expected = f"[red]{expected}[/red]"
            else:
                colored_expected = f"[dim]{expected}[/dim]"

            table.add_row(
                colored_line,
                transport,  # colored_transport seemed too distracting
                colored_direction,
                colored_scheduled,
                colored_expected,
                colored_status,
                platform,
            )

        return table

    def _get_departures(
        self,
        site_id: int,
        params: dict[str, Any],
        show_direction_numbers: bool = False,
        debug: bool = False,
        site_info: dict[str, Any] | None = None,
    ) -> bool:
        """Fetch and display departures for a given site ID.

        Args:
            site_id: SL site ID to query.
            params: Query parameters. ``show_numbers`` is dropped. Keys starting with
                ``_`` are sent as HTTP headers instead: the leading ``_`` is removed
                and the remaining ``_`` become ``-``. All other keys are sent as
                query parameters.
            show_direction_numbers: Show direction codes next to directions.
            debug: Print request and response headers.
            site_info: Cached site record used for the heading, if available.

        Returns:
            False if the request failed (network error, HTTP error status, or a
            non-JSON body). True otherwise, including when no departures matched.
        """
        # Remove show_numbers from API params if present
        api_params = params.copy()
        api_params.pop("show_numbers", None)

        # Extract headers from underscore-prefixed parameters
        headers = {}
        for param, value in list(api_params.items()):
            if param.startswith("_"):
                header_name = param[1:].replace("_", "-")
                headers[header_name] = value
                api_params.pop(param)

        if debug:
            self.console.print("\n[bold yellow]Request Headers:[/bold yellow]")
            for header, value in headers.items():
                self.console.print(f"{header}: {value}")
            self.console.print("\n")

        # cache busting: current timestamp in milliseconds
        api_params["_t"] = str(int(time.time() * 1000))

        try:
            response = requests.get(
                f"{self.BASE_URL}/sites/{site_id}/departures",
                params=api_params,
                headers=headers,
                timeout=self.REQUEST_TIMEOUT,
            )

            # Print headers before checking the status so error responses can be debugged too.
            if debug:
                self.console.print("\n[bold yellow]Response Headers:[/bold yellow]")
                for header, value in response.headers.items():
                    self.console.print(f"{header}: {value}")
                self.console.print("\n")

            response.raise_for_status()
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers JSON decode errors raised by older `requests` releases.
            self.console.print(f"[red]Error fetching departures: {str(e)}[/red]")
            return False

        if not data.get("departures"):
            self.console.print(
                "[yellow]No departures found for the given criteria.[/yellow]"
            )
            return True

        # Display site information
        if site_info:
            site_name = site_info["name"]
            self.console.print(
                f"\n[bold white]Site: {site_name} ({site_id})[/bold white]"
            )
            if site_info.get("note"):
                self.console.print(f"[blue]{site_info['note']}[/blue]")
        else:
            self.console.print(f"\n[bold blue]Site ID: {site_id}[/bold blue]")

        table = self._create_departure_table(data["departures"], show_direction_numbers)
        self.console.print(table)

        # Display any deviations
        if data.get("stop_deviations"):
            self.console.print("\n[bold red]Deviations:[/bold red]")
            for dev in data["stop_deviations"]:
                self.console.print(f"- {dev['message']}")

        return True

    def _parse_query(self, query: str) -> tuple[bool, dict[str, str]]:
        """Parse the input query and return parameters.

        Returns:
            ``(True, params)`` on success, where params is a command dict (``help``
            or ``lookup``) or departure parameters including ``site``. On failure,
            returns ``(False, {"error": message})``. The message is empty for a
            blank query, so nothing is printed.
        """
        params = {}
        parts = query.strip().split()

        if not parts:
            return False, {"error": ""}

        # Check for special commands first
        first_part = parts[0].lower()

        if first_part == "help":
            return True, {"command": "help"}

        if first_part.startswith("lookup:"):
            lookup_type = first_part.split(":", 1)[1].lower()
            if lookup_type not in ["id", "name"]:
                return False, {
                    "error": f"Invalid lookup type: {lookup_type}. Must be 'id' or 'name'"
                }

            if len(parts) < 2:
                return False, {"error": f"Lookup {lookup_type} requires a search term"}

            search_term = " ".join(parts[1:])  # Join all remaining parts as search term

            return True, {
                "command": "lookup",
                "lookup_type": lookup_type,
                "search_term": search_term,
            }

        # If first part doesn't have a label, assume it's the site ID
        if ":" not in parts[0]:
            if not re.match(r"^\d+$", parts[0]):
                return False, {"error": f"Invalid site ID: {parts[0]}"}
            params["site"] = parts[0]
            parts = parts[1:]

        # Process labeled parameters
        for part in parts:
            if ":" not in part:
                return False, {
                    "error": f"Invalid parameter format: {part}. Must be param:value"
                }

            param, value = part.split(":", 1)
            param = param.lower()

            # Skip validation for underscore-prefixed parameters (sent as headers)
            if param.startswith("_"):
                params[param] = value
                continue

            # Exact comparison: `param in ("transport")` was a substring test on a str.
            if param == "transport":
                value = value.upper()
            else:
                value = value.lower()

            # Validate parameter name
            if param not in self.VALID_PARAMS:
                return False, {"error": f"Invalid parameter name: {param}"}

            # Validate parameter value
            if not re.match(f"^{self.VALID_PARAMS[param]}$", value):
                return False, {"error": f"Invalid value for {param}: {value}"}

            params[param] = value

        # Ensure site ID is provided for departure queries
        if "site" not in params:
            return False, {"error": "Site ID is required"}

        return True, params

    def _create_completer(self) -> WordCompleter:
        """Create a completer for commands, parameter names, and common values."""
        words = []

        # Add special commands
        words.extend(["help", "lookup:id", "lookup:name"])

        # Add parameter names with colon
        for param in self.VALID_PARAMS:
            words.append(f"{param}:")
        # Add transport modes with prefix
        for mode in self.TRANSPORT_MODES:
            words.append(f"transport:{mode}")
        # Add directions with prefix
        words.extend(["direction:1", "direction:2"])
        # Add show_numbers options
        words.extend(["show_numbers:true", "show_numbers:false"])

        # Word boundaries exclude ':' so "transport:B" completes as one token
        return WordCompleter(words, ignore_case=True, pattern=re.compile(r"^|[^\w:]"))

    def _show_help(self) -> None:
        """Display help information for available commands."""
        from rich.panel import Panel

        # Derived from the class constants so the help cannot drift from the parser.
        transport_modes = ", ".join(self.TRANSPORT_MODES)
        help_text = f"""[bold cyan]Available Commands:[/bold cyan]

[bold yellow]Departure Queries:[/bold yellow]
  [green]<site_id>[/green]                      - Get departures for site (e.g., 1002)
  [green]site:<site_id>[/green]                 - Explicit site parameter
  [green]<site_id> line:<line_id>[/green]       - Filter by line (e.g., 1002 line:17)
  [green]<site_id> direction:<1|2>[/green]      - Filter by direction
  [green]<site_id> transport:<mode>[/green]     - Filter by transport mode
  [green]<site_id> forecast:<minutes>[/green]   - Set forecast window (default: {self.DEFAULT_FORECAST})
  [green]<site_id> show_numbers:true[/green]    - Show direction numbers
  [green]<site_id> debug:true[/green]           - Show request/response headers

[bold yellow]Site Lookup:[/bold yellow]
  [green]lookup:id <site_id>[/green]            - Find site by ID (e.g., lookup:id 1002)
  [green]lookup:name <search_term>[/green]      - Find sites by name (e.g., lookup:name odenplan)

[bold yellow]Other Commands:[/bold yellow]
  [green]help[/green]                           - Show this help message
  [green]quit[/green]                           - Exit the program

[bold yellow]Transport Modes:[/bold yellow]
  {transport_modes}

[bold yellow]Examples:[/bold yellow]
  1002                       - Basic departure lookup
  1002 line:17 direction:1   - Green line towards Åkeshov
  lookup:name central        - Find stations with "central" in name
  lookup:id 9001             - Find specific station by ID"""

        panel = Panel(
            help_text,
            title="[bold white]SL Transit REPL Help[/bold white]",
            border_style="blue",
        )
        self.console.print(panel)

    def run(self) -> None:
        """Run the interactive REPL session.

        Prints a usage banner and warns if no sites data is loaded, then runs the
        prompt loop. A KeyboardInterrupt is reported and re-raised to the caller.
        """
        self.console.print("[bold blue]SL Transport REPL[/bold blue]")
        self.console.print("Examples:")
        self.console.print(
            "  1002                        (departure query: just site ID)"
        )
        self.console.print(
            "  1002 line:17 direction:1    (departure query: with line and direction)"
        )
        self.console.print(
            "  lookup:id 1002              (site lookup: find site by ID)"
        )
        self.console.print(
            "  lookup:name odenplan        (site lookup: find sites by name)"
        )
        self.console.print(
            "  help                        (show detailed help)"
        )
        self.console.print("\nEnter 'quit' to exit")
        self.console.print("Use ↑/↓ arrows to access command history\n")

        # Check sites data availability
        if not self.sites:
            self.console.print(
                "[red]Warning: No sites data available. Some features will be limited.[/red]"
            )

        try:
            self._run_loop()
        except KeyboardInterrupt:
            self.console.print("\n[yellow]Program interrupted by user[/yellow]")
            raise

    def execute_query(self, query: str) -> bool:
        """Execute a single query and return whether it was successful.

        Args:
            query: The query string to execute

        Returns:
            True if the query was executed successfully. False for parse errors,
            lookups with no match, unknown site IDs, or a failed departures request.
        """
        # Parse and validate query
        valid, result = self._parse_query(query)
        if not valid:
            if result["error"]:
                self.console.print(f"[red]{result['error']}[/red]")
            return False

        # Handle special commands
        command = result.get("command")

        if command == "help":
            self._show_help()
            return True

        elif command == "lookup":
            lookup_type = result["lookup_type"]
            search_term = result["search_term"]

            if lookup_type == "id":
                try:
                    site_id = int(search_term)
                    site = self._find_site_by_id(site_id)
                    if site:
                        table = self._create_site_table(site, f"Site with ID {site_id}")
                        self.console.print(table)
                    else:
                        self.console.print(
                            f"[red]No site found with ID {site_id}[/red]"
                        )
                        return False
                except ValueError:
                    self.console.print(
                        "[red]Invalid ID format. Please enter a number.[/red]"
                    )
                    return False

            elif lookup_type == "name":
                sites = self._find_sites_by_substring(search_term)
                if sites:
                    table = self._create_site_table(
                        sites, f"Sites matching '{search_term}'"
                    )
                    self.console.print(table)
                else:
                    self.console.print(
                        f"[red]No sites found containing '{search_term}'[/red]"
                    )
                    return False

            return True

        # Build API parameters for departure queries
        api_params = {"forecast": result.get("forecast", self.DEFAULT_FORECAST)}

        if "transport" in result:
            api_params["transport"] = result["transport"]
        if "direction" in result:
            api_params["direction"] = result["direction"]
        if "line" in result:
            api_params["line"] = result["line"]

        # Get show_numbers preference
        show_direction_numbers = result.get("show_numbers", "").lower() == "true"

        # Get debug preference
        debug = result.get("debug", "").lower() == "true"

        # Underscore-prefixed params are forwarded; _get_departures turns them into headers
        for key, value in result.items():
            if key.startswith("_"):
                api_params[key] = value

        site_id = str(result["site"])

        # Get site info from our dictionary
        site_info = self._get_site_info(site_id)
        if not site_info:
            return False

        # Fetch and display departures; report request failures to the caller
        return self._get_departures(
            int(site_id), api_params, show_direction_numbers, debug, site_info
        )

    def _run_loop(self) -> None:
        """Prompt for queries until 'quit' or EOF, executing each one."""
        completer = self._create_completer()
        history = FileHistory(self.history_file)
        prompt_text = "Enter query (site[:id] [line:id] [forecast:minutes]): "

        while True:
            # Get query with history support
            try:
                query = prompt(
                    prompt_text,
                    completer=completer,
                    history=history,
                    complete_while_typing=True,
                ).strip()
            except EOFError:
                break

            if query.lower() == "quit":
                break

            self.execute_query(query)

774 

775 

def main():
    """Command-line entry point for the ``sl-repl`` tool.

    With a positional query, runs that single query and exits with status 0 on
    success or 1 on failure. Without one, starts the interactive REPL. A Ctrl-C
    in either mode exits with status 0.
    """
    usage_examples = "\n".join(
        [
            "Examples:",
            "  %(prog)s                         # Start interactive REPL",
            '  %(prog)s "1002"                  # Get departures for site 1002',
            '  %(prog)s "1002 line:17"          # Get departures for line 17 at site 1002',
            "  %(prog)s \"lookup:name central\"   # Find sites containing 'central'",
            '  %(prog)s "lookup:id 1002"        # Get info for site ID 1002',
        ]
    )

    cli = argparse.ArgumentParser(
        description="SL Transit REPL - Query Stockholm's public transit departures",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    cli.add_argument(
        "query",
        nargs="?",
        help="Query to execute (if not provided, starts interactive REPL)",
    )
    cli.add_argument(
        "--app-dir",
        help="Custom application directory path (default: ~/.sl_transit_repl)",
    )
    options = cli.parse_args()

    try:
        session = SLTransitREPL(app_dir=options.app_dir)

        if not options.query:
            # No query given: hand control to the interactive prompt.
            session.run()
            return

        # One-shot mode: the process exit status reflects the query outcome.
        succeeded = session.execute_query(options.query)
        sys.exit(0 if succeeded else 1)
    except KeyboardInterrupt:
        sys.exit(0)

816 

817 

if "__main__" == __name__:
    # Allow `python main.py` in addition to the installed sl-repl entry point.
    main()