Source code for src.pipeline.prelabelling.grounding_dino_prelabelling

import os
import json
import warnings
import cv2
from pathlib import Path
from typing import List, Dict, Optional
from tqdm import tqdm
import contextlib
import io
from utils import detect_device

# Suppress warnings for clean output
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

# Default detection classes and thresholds
TEXT_PROMPTS = ["fire", "smoke", "person", "vehicle", "lightning"]
BOX_THRESHOLD = 0.3
TEXT_THRESHOLD = 0.25
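
# box_threshold filters predicted boxes by detection confidence, while
# text_threshold sets how strongly a box must match one of the text prompts.
# Both can be overridden per run via the "dino_box_threshold" and
# "dino_text_threshold" config keys (see generate_gd_prelabelling below).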


def _get_image_files(directory: Path) -> List[Path]:
    """
    Get all image files (jpg, jpeg, png) from the given directory (non-recursive).
    """
    return [f for f in directory.glob("*") if f.suffix.lower() in {".jpg", ".jpeg", ".png"}]


def generate_gd_prelabelling(
    raw_dir: Path,
    output_dir: Path,
    config: Dict,
    model_weights: Path,
    config_path: Path,
    text_prompts: List[str] = TEXT_PROMPTS,
    box_threshold: Optional[float] = None,
    text_threshold: Optional[float] = None,
) -> None:
    """
    Run Grounding DINO to detect objects in images using text prompts.

    Saves one JSON file per image with the predicted bounding boxes and
    metadata. This version runs sequentially (no multiprocessing).
    """
    # Explicit arguments take precedence; otherwise fall back to the config
    # values, then to the module defaults
    if box_threshold is None:
        box_threshold = config.get("dino_box_threshold", BOX_THRESHOLD)
    if text_threshold is None:
        text_threshold = config.get("dino_text_threshold", TEXT_THRESHOLD)

    # Get the list of all images to process
    image_files = _get_image_files(raw_dir)
    print(f"Found {len(image_files)} images to process")

    # Create the output directory if it doesn't exist
    output_dir.mkdir(parents=True, exist_ok=True)

    # Import Grounding DINO here to keep module import lightweight
    from groundingdino.util.inference import Model

    # Resolve the torch device, auto-detecting if requested
    device = config.get("torch_device", "auto")
    if device == "auto":
        device = detect_device()
    print(f"Using device: {device}")

    # Load the model once; suppress its internal logs
    with contextlib.redirect_stdout(io.StringIO()):
        model = Model(
            model_config_path=str(config_path),
            model_checkpoint_path=str(model_weights),
            device=device,
        )

    # Track results for the summary
    success_count = 0
    skipped_count = 0
    error_count = 0

    # Process each image one by one
    for image_path in tqdm(image_files, desc="Processing images"):
        image_name = image_path.stem
        image = cv2.imread(str(image_path))
        if image is None:
            skipped_count += 1
            continue

        try:
            # Run inference with the class prompts
            detections = model.predict_with_classes(
                image=image,
                classes=text_prompts,
                box_threshold=box_threshold,
                text_threshold=text_threshold,
            )

            # Format the detections as per-image metadata
            metadata = []
            for box, score, class_id in zip(
                detections.xyxy, detections.confidence, detections.class_id
            ):
                label = text_prompts[class_id] if class_id is not None else "unknown"
                metadata.append({
                    "image": image_name,
                    "class": label,
                    "confidence": float(score),
                    "bbox": [float(x) for x in box],
                    "source": "grounding_dino",
                })

            # Save the results to JSON
            output_file = output_dir / f"{image_name}.json"
            with open(output_file, "w") as f:
                json.dump({"predictions": metadata}, f, indent=2)

            success_count += 1

        except Exception:
            error_count += 1
            continue

    # Print summary
    print("\nPrediction Summary:")
    print(f"Successfully processed: {success_count} images")
    print(f"Skipped (unreadable): {skipped_count} images")
    print(f"Failed to process: {error_count} images")
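

# A minimal usage sketch, not part of the pipeline: every path below is a
# placeholder. The weight and config filenames follow the official Grounding
# DINO release naming, but should be adjusted to wherever the checkpoint was
# actually downloaded.
if __name__ == "__main__":
    example_config = {
        "torch_device": "auto",        # resolved via detect_device()
        "dino_box_threshold": 0.3,     # optional; defaults to BOX_THRESHOLD
        "dino_text_threshold": 0.25,   # optional; defaults to TEXT_THRESHOLD
    }
    generate_gd_prelabelling(
        raw_dir=Path("data/raw"),                          # placeholder input dir
        output_dir=Path("data/prelabels/grounding_dino"),  # placeholder output dir
        config=example_config,
        model_weights=Path("weights/groundingdino_swint_ogc.pth"),
        config_path=Path("configs/GroundingDINO_SwinT_OGC.py"),
    )
    # Each resulting JSON file has the shape:
    # {
    #   "predictions": [
    #     {"image": "<stem>", "class": "fire", "confidence": 0.42,
    #      "bbox": [x1, y1, x2, y2], "source": "grounding_dino"}
    #   ]
    # }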