# Source code for geococo.coco_models

from __future__ import annotations
import numpy as np
import pathlib
import pandas as pd
from pandas import Series
from datetime import datetime
from typing import List, Optional, Dict, Any
from typing_extensions import TypedDict
from pydantic import BaseModel, root_validator
from pydantic.fields import Field
from semver.version import Version


class CocoDataset(BaseModel):
    """Pydantic model of a COCO dataset with version and id bookkeeping.

    The ``next_*`` fields are excluded from serialization (``exclude=True``)
    and are re-derived from the list fields by ``_set_ids`` on validation.
    """

    info: Info
    images: List[Image] = []
    annotations: List[Annotation] = []
    categories: List[Category] = []
    sources: List[Source] = []
    next_image_id: int = Field(default=1, exclude=True)
    next_annotation_id: int = Field(default=1, exclude=True)
    # Unlike the two counters above, this holds the id of the most recently
    # added or re-used source (see ``add_source``), not the next free id.
    next_source_id: int = Field(default=1, exclude=True)

    @root_validator
    def _set_ids(cls: CocoDataset, values: Dict[str, Any]) -> Dict[str, Any]:
        """Derive the id counters from the lengths of the validated lists.

        ``.get`` is used because pydantic v1 leaves fields that failed
        validation out of ``values``; indexing them directly would raise a raw
        KeyError instead of letting the ValidationError surface.
        """
        values["next_image_id"] = len(values.get("images", [])) + 1
        values["next_annotation_id"] = len(values.get("annotations", [])) + 1
        values["next_source_id"] = len(values.get("sources", []))
        return values

    def add_annotation(self, annotation: Annotation) -> None:
        """Append ``annotation`` and advance ``next_annotation_id`` by one.

        The annotation's own ``id`` is neither checked nor assigned here.
        """
        self.annotations.append(annotation)
        self.next_annotation_id += 1

    def add_image(self, image: Image) -> None:
        """Append ``image`` and advance ``next_image_id`` by one.

        The image's own ``id`` is neither checked nor assigned here.
        """
        self.images.append(image)
        self.next_image_id += 1

    def add_source(self, source_path: pathlib.Path, date_captured: datetime) -> None:
        """Select ``source_path`` as the current source and bump the version.

        A path already registered in ``sources`` is re-used (patch bump); an
        unseen path is appended as a new Source with id ``len(sources) + 1``
        (minor bump). Either way ``next_source_id`` is set to that source's id.
        """
        matches = [src for src in self.sources if src.file_name == source_path]
        if matches:
            assert len(matches) == 1, f"source {source_path} is registered more than once"
            source = matches[0]
            self.bump_version(bump_method="patch")
        else:
            source = Source(
                id=len(self.sources) + 1,
                file_name=source_path,
                date_captured=date_captured,
            )
            self.sources.append(source)
            self.bump_version(bump_method="minor")
        self.next_source_id = source.id

    def add_categories(
        self,
        category_ids: Optional[Series],
        category_names: Optional[Series],
        super_names: Optional[Series],
    ) -> None:
        """Append a Category for every id/name not already in ``categories``.

        Ids are the unique key when given, otherwise names are. Duplicates
        within the input and values already present in ``self.categories``
        are skipped. When names are missing they default to the ids' string
        form; when ids are missing they continue after the largest existing
        category id; when supercategories are missing they default to ``"1"``.

        Args:
            category_ids: Category ids, or None.
            category_names: Category names, or None.
            super_names: Supercategory names aligned element-wise with the
                ids/names, or None.

        Raises:
            AttributeError: If both ``category_ids`` and ``category_names``
                are None.
        """
        super_default = "1"

        # All existing categories as one frame; the explicit columns keep the
        # "id"/"name" columns available even when there are no categories yet.
        category_pd = pd.DataFrame(
            [category.dict() for category in self.categories],
            columns=list(Category.schema()["properties"].keys()),
        )

        names_arr: Optional[np.ndarray] = None
        ids_arr: Optional[np.ndarray] = None
        if category_names is not None:
            names_arr = category_names.to_numpy()
        if category_ids is not None:
            ids_arr = category_ids.to_numpy()
        if names_arr is None and ids_arr is None:
            raise AttributeError("At least one category attribute must be present")

        # ids are the leading unique key whenever they are given
        if ids_arr is not None:
            uid_array, uid_attribute = ids_arr, "id"
        else:
            uid_array, uid_attribute = names_arr, "name"

        original_shape = uid_array.shape
        # Deduplicate the input. ``indices`` gives the first position of every
        # unique value so aligned arrays (names, supers) can be sliced the same way.
        _, indices = np.unique(uid_array, return_index=True)
        uid_array = uid_array[indices]
        member_mask = np.isin(uid_array, category_pd[uid_attribute])
        new_members = uid_array[~member_mask]
        if new_members.shape[0] == 0:
            return  # nothing new to add

        if super_names is None:
            supers = np.full(shape=new_members.shape, fill_value=super_default)
        else:
            supers = super_names.to_numpy()
            assert supers.shape == original_shape, "super_names must align with the category ids/names"
            supers = supers[indices][~member_mask]

        if names_arr is None:
            # only ids given: names default to the ids' string representation
            new_ids = new_members
            new_names = new_members.astype(str)
        elif ids_arr is None:
            # Only names given: number the new categories after the largest id
            # of *all* existing categories so no id already in use is reused.
            max_id = category_pd["id"].max()
            start = 1 if pd.isna(max_id) else int(max_id) + 1
            new_ids = np.arange(start, start + new_members.size)
            new_names = new_members
        else:
            assert names_arr.shape == original_shape, "category_names must align with category_ids"
            new_names = names_arr[indices][~member_mask]
            new_ids = new_members

        for cid, name, supercategory in zip(new_ids, new_names, supers):
            self.categories.append(Category(id=cid, name=name, supercategory=supercategory))

    def bump_version(self, bump_method: str) -> None:
        """Bump the semantic version stored in ``info.version`` in place.

        Args:
            bump_method: One of ``"patch"``, ``"minor"`` or ``"major"``.

        Raises:
            ValueError: If ``bump_method`` is not one of the above (the current
                version string is parsed first, so a malformed version fails
                before this check).
        """
        bump_methods = ["patch", "minor", "major"]
        version = Version.parse(self.info.version)
        if bump_method not in bump_methods:
            raise ValueError(f"bump_method needs to be one of {bump_methods}")
        elif bump_method == bump_methods[0]:
            version = version.bump_patch()
        elif bump_method == bump_methods[1]:
            version = version.bump_minor()
        else:
            version = version.bump_major()
        self.info.version = str(version)

    def verify_used_dir(self, images_dir: pathlib.Path) -> None:
        """Bump the major version unless ``images_dir`` is already in use.

        "In use" means it equals the parent directory of the ``file_name`` of
        at least one image already in the dataset.
        """
        used_dirs = {image.file_name.parent for image in self.images}
        if images_dir not in used_dirs:
            self.bump_version(bump_method="major")
class Info(BaseModel):
    """The ``info`` section of a COCO dataset.

    ``version`` holds a semantic-version string that starts at
    ``str(Version(major=0))``; every other entry is optional metadata.
    """

    version: str = Field(default=str(Version(major=0)))
    year: Optional[int] = Field(default=None)
    description: Optional[str] = Field(default=None)
    contributor: Optional[str] = Field(default=None)
    date_created: Optional[datetime] = Field(default=None)
class Image(BaseModel):
    """A COCO image entry; ``source_id`` refers to a ``Source`` by its id."""

    id: int = Field(...)
    width: int = Field(...)
    height: int = Field(...)
    file_name: pathlib.Path = Field(...)
    source_id: int = Field(...)
    date_captured: datetime = Field(...)
class Annotation(BaseModel):
    """A COCO annotation entry with an RLE-encoded segmentation mask.

    ``image_id`` and ``category_id`` refer to an ``Image`` and a ``Category``
    by id.
    """

    id: int = Field(...)
    image_id: int = Field(...)
    category_id: int = Field(...)
    segmentation: RleDict = Field(...)
    area: float = Field(...)
    bbox: List[int] = Field(...)
    iscrowd: int = Field(...)
class Category(BaseModel):
    """A COCO category: integer id, name and supercategory name."""

    id: int = Field(...)
    name: str = Field(...)
    supercategory: str = Field(...)
class RleDict(TypedDict):
    """Run-length-encoded segmentation mask stored in ``Annotation.segmentation``."""

    size: List[int]  # presumably the mask's [height, width] as in COCO RLE -- TODO confirm
    counts: bytes  # encoded run lengths; presumably compressed RLE bytes -- verify against producer
class Source(BaseModel):
    """A source file that images were taken from, referenced by ``Image.source_id``."""

    id: int = Field(...)
    file_name: pathlib.Path = Field(...)
    date_captured: datetime = Field(...)
# With ``from __future__ import annotations`` every annotation is a string, and
# several models reference classes defined after them; pydantic < 2.0.0 needs an
# explicit update_forward_refs() call on each model to resolve them.
for _model in (CocoDataset, Info, Image, Annotation, Category, Source):
    _model.update_forward_refs()
del _model