Skip to content

Commit 5b39858

Browse files
committed
Alignerr project creation
1 parent 2fab8c9 commit 5b39858

17 files changed

+1793
-1
lines changed

libs/labelbox/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ dependencies = [
1212
"tqdm>=4.66.2",
1313
"geojson>=3.1.0",
1414
"lbox-clients==1.1.2",
15+
"PyYAML>=6.0",
1516
]
1617
readme = "README.md"
1718
requires-python = ">=3.9,<3.14"

libs/labelbox/src/labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
from labelbox.schema.ontology_kind import OntologyKind
7979
from labelbox.schema.organization import Organization
8080
from labelbox.schema.project import Project
81+
from labelbox.alignerr.schema.project_rate import ProjectRateV2 as ProjectRate
8182
from labelbox.schema.project_model_config import ProjectModelConfig
8283
from labelbox.schema.project_overview import (
8384
ProjectOverview,
@@ -98,7 +99,6 @@
9899
ResponseOption,
99100
PromptResponseClassification,
100101
)
101-
from lbox.exceptions import *
102102
from labelbox.schema.taskstatus import TaskStatus
103103
from labelbox.schema.api_key import ApiKey
104104
from labelbox.schema.timeunit import TimeUnit
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .alignerr_project import AlignerrWorkspace
2+
3+
__all__ = ['AlignerrWorkspace']
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
import datetime
2+
from enum import Enum
3+
from typing import TYPE_CHECKING, Optional
4+
import yaml
5+
from pathlib import Path
6+
7+
import logging
8+
9+
from labelbox.alignerr.schema.project_rate import BillingMode
10+
from labelbox.alignerr.schema.project_rate import ProjectRateInput
11+
from labelbox.alignerr.schema.project_rate import ProjectRateV2
12+
from labelbox.alignerr.schema.project_domain import ProjectDomain
13+
from labelbox.pagination import PaginatedCollection
14+
from labelbox.schema.media_type import MediaType
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
if TYPE_CHECKING:
20+
from labelbox import Client
21+
from labelbox.schema.project import Project
22+
from labelbox.alignerr.schema.project_domain import ProjectDomain
23+
24+
25+
class AlignerrRole(Enum):
26+
Labeler = "LABELER"
27+
Reviewer = "REVIEWER"
28+
Admin = "ADMIN"
29+
30+
31+
class AlignerrProject:
32+
def __init__(self, client: "Client", project: "Project", _internal: bool = False):
33+
if not _internal:
34+
raise RuntimeError(
35+
"AlignerrProject cannot be initialized directly. "
36+
"Use AlignerrProjectBuilder or AlignerrProjectFactory to create instances."
37+
)
38+
self.client = client
39+
self.project = project
40+
41+
@property
42+
def project(self) -> Optional["Project"]:
43+
return self._project
44+
45+
@project.setter
46+
def project(self, project: "Project"):
47+
self._project = project
48+
49+
50+
def domains(self) -> PaginatedCollection:
51+
"""Get all domains associated with this project.
52+
53+
Returns:
54+
PaginatedCollection of ProjectDomain instances
55+
"""
56+
return ProjectDomain.get_by_project_id(
57+
client=self.client,
58+
project_id=self.project.uid
59+
)
60+
61+
def add_domain(self, project_domain: ProjectDomain):
62+
return ProjectDomain.connect_project_to_domains(
63+
client=self.client,
64+
project_id=self.project.uid,
65+
domain_ids=[project_domain.uid]
66+
)
67+
68+
def get_project_rate(self) -> Optional["ProjectRateV2"]:
69+
return ProjectRateV2.get_by_project_id(
70+
client=self.client,
71+
project_id=self.project.uid
72+
)
73+
74+
def set_project_rate(self, project_rate_input: ProjectRateInput):
75+
return ProjectRateV2.set_project_rate(
76+
client=self.client,
77+
project_id=self.project.uid,
78+
project_rate_input=project_rate_input
79+
)
80+
81+
82+
83+
class AlignerrProjectBuilder:
84+
def __init__(self, client: "Client"):
85+
self.client = client
86+
self._alignerr_rates: dict[str, ProjectRateInput] = {}
87+
self._customer_rate: ProjectRateInput = None
88+
self._domains: list[ProjectDomain] = []
89+
self.role_name_to_id = self._get_role_name_to_id()
90+
91+
def set_name(self, name: str):
92+
self.project_name = name
93+
return self
94+
95+
def set_media_type(self, media_type: "MediaType"):
96+
self.project_media_type = media_type
97+
return self
98+
99+
def set_alignerr_role_rate(
100+
self,
101+
*,
102+
role_name: AlignerrRole,
103+
rate: float,
104+
billing_mode: BillingMode,
105+
effective_since: datetime.datetime,
106+
effective_until: Optional[datetime.datetime] = None,
107+
):
108+
if role_name.value not in self.role_name_to_id:
109+
raise ValueError(f"Role {role_name.value} not found")
110+
111+
role_id = self.role_name_to_id[role_name.value]
112+
role_name = role_name.value
113+
114+
# Convert datetime objects to ISO format strings
115+
effective_since_str = effective_since.isoformat() if isinstance(effective_since, datetime.datetime) else effective_since
116+
effective_until_str = effective_until.isoformat() if isinstance(effective_until, datetime.datetime) else effective_until
117+
118+
self._alignerr_rates[role_name] = ProjectRateInput(
119+
rateForId=role_id,
120+
isBillRate=False,
121+
billingMode=billing_mode,
122+
rate=rate,
123+
effectiveSince=effective_since_str,
124+
effectiveUntil=effective_until_str,
125+
)
126+
return self
127+
128+
def set_customer_rate(
129+
self,
130+
*,
131+
rate: float,
132+
billing_mode: BillingMode,
133+
effective_since: datetime.datetime,
134+
effective_until: Optional[datetime.datetime] = None,
135+
):
136+
# Convert datetime objects to ISO format strings
137+
effective_since_str = effective_since.isoformat() if isinstance(effective_since, datetime.datetime) else effective_since
138+
effective_until_str = effective_until.isoformat() if isinstance(effective_until, datetime.datetime) else effective_until
139+
140+
self._customer_rate = ProjectRateInput(
141+
rateForId="", # Empty string for customer rate
142+
isBillRate=True,
143+
billingMode=billing_mode,
144+
rate=rate,
145+
effectiveSince=effective_since_str,
146+
effectiveUntil=effective_until_str,
147+
)
148+
return self
149+
150+
def set_domains(self, domains: list[str]):
151+
for domain in domains:
152+
project_domain_page = ProjectDomain.search(self.client, search_by_name=domain)
153+
domain_result = project_domain_page.get_one()
154+
if domain_result is None:
155+
raise ValueError(f"Domain {domain} not found")
156+
self._domains.append(domain_result)
157+
return self
158+
159+
160+
def create(self, skip_validation: bool = False):
161+
if not skip_validation:
162+
self._validate()
163+
logger.info("Creating project")
164+
165+
project_data = {
166+
"name": self.project_name,
167+
"media_type": self.project_media_type,
168+
}
169+
labelbox_project = self.client.create_project(**project_data)
170+
alignerr_project = AlignerrProject(self.client, labelbox_project, _internal=True)
171+
172+
self._create_rates(alignerr_project)
173+
self._create_domains(alignerr_project)
174+
175+
return alignerr_project
176+
177+
def _create_rates(self, alignerr_project: AlignerrProject):
178+
for alignerr_role, project_rate in self._alignerr_rates.items():
179+
logger.info(f"Setting project rate for {alignerr_role}")
180+
alignerr_project.set_project_rate(project_rate)
181+
182+
def _create_domains(self, alignerr_project: AlignerrProject):
183+
if self._domains:
184+
logger.info(f"Setting domains: {[domain.name for domain in self._domains]}")
185+
domain_ids = [domain.uid for domain in self._domains]
186+
ProjectDomain.connect_project_to_domains(
187+
client=self.client,
188+
project_id=alignerr_project.project.uid,
189+
domain_ids=domain_ids
190+
)
191+
192+
def _validate_alignerr_rates(self):
193+
required_role_rates = set([AlignerrRole.Labeler.value, AlignerrRole.Reviewer.value])
194+
195+
for role_name in self._alignerr_rates.keys():
196+
required_role_rates.remove(role_name)
197+
if len(required_role_rates) > 0:
198+
raise ValueError(
199+
f"Required role rates are not set: {required_role_rates}"
200+
)
201+
202+
def _validate_customer_rate(self):
203+
if self._customer_rate is None:
204+
raise ValueError("Customer rate is not set")
205+
206+
def _validate(self):
207+
self._validate_alignerr_rates()
208+
self._validate_customer_rate()
209+
210+
def _get_role_name_to_id(self) -> dict[str, str]:
211+
roles = self.client.get_roles()
212+
return {role.name: role.uid for role in roles.values()}
213+
214+
215+
class AlignerrProjectFactory:
216+
def __init__(self, client: "Client"):
217+
self.client = client
218+
219+
def create(self, yaml_file_path: str, skip_validation: bool = False):
220+
"""
221+
Create an AlignerrProject from a YAML configuration file.
222+
223+
Args:
224+
yaml_file_path: Path to the YAML configuration file
225+
skip_validation: Whether to skip validation of required fields
226+
227+
Returns:
228+
AlignerrProject: The created project with configured rates
229+
230+
Raises:
231+
FileNotFoundError: If the YAML file doesn't exist
232+
yaml.YAMLError: If the YAML file is invalid
233+
ValueError: If required fields are missing or invalid
234+
"""
235+
logger.info(f"Creating project from YAML file: {yaml_file_path}")
236+
237+
# Load and parse YAML file
238+
yaml_path = Path(yaml_file_path)
239+
if not yaml_path.exists():
240+
raise FileNotFoundError(f"YAML file not found: {yaml_file_path}")
241+
242+
try:
243+
with open(yaml_path, 'r') as file:
244+
config = yaml.safe_load(file)
245+
except yaml.YAMLError as e:
246+
raise yaml.YAMLError(f"Invalid YAML file: {e}")
247+
248+
# Validate required fields
249+
if not config:
250+
raise ValueError("YAML file is empty")
251+
252+
required_fields = ['name', 'media_type']
253+
for field in required_fields:
254+
if field not in config:
255+
raise ValueError(f"Required field '{field}' is missing from YAML configuration")
256+
257+
# Create project builder
258+
builder = AlignerrProjectBuilder(self.client)
259+
260+
# Set basic project properties
261+
builder.set_name(config['name'])
262+
263+
# Set media type
264+
media_type_str = config['media_type']
265+
media_type = MediaType(media_type_str)
266+
267+
# Check if the media type is supported
268+
if not MediaType.is_supported(media_type):
269+
supported_members = MediaType.get_supported_members()
270+
raise ValueError(f"Invalid media_type '{media_type_str}'. Must be one of: {supported_members}")
271+
272+
builder.set_media_type(media_type)
273+
274+
# Set project rates if provided
275+
if 'rates' in config:
276+
rates_config = config['rates']
277+
if not isinstance(rates_config, dict):
278+
raise ValueError("'rates' must be a dictionary")
279+
280+
for role_name, rate_config in rates_config.items():
281+
try:
282+
alignerr_role = AlignerrRole(role_name.upper())
283+
except ValueError:
284+
raise ValueError(f"Invalid role '{role_name}'. Must be one of: {[r.value for r in AlignerrRole]}")
285+
286+
# Validate rate configuration
287+
required_rate_fields = ['rate', 'billing_mode', 'effective_since']
288+
for field in required_rate_fields:
289+
if field not in rate_config:
290+
raise ValueError(f"Required field '{field}' is missing for role '{role_name}'")
291+
292+
# Parse billing mode
293+
try:
294+
billing_mode = BillingMode(rate_config['billing_mode'])
295+
except ValueError:
296+
raise ValueError(f"Invalid billing_mode '{rate_config['billing_mode']}' for role '{role_name}'. Must be one of: {[e.value for e in BillingMode]}")
297+
298+
# Parse effective dates
299+
try:
300+
effective_since = datetime.datetime.fromisoformat(rate_config['effective_since'])
301+
except ValueError:
302+
raise ValueError(f"Invalid effective_since date format for role '{role_name}'. Use ISO format (YYYY-MM-DDTHH:MM:SS)")
303+
304+
effective_until = None
305+
if 'effective_until' in rate_config and rate_config['effective_until']:
306+
try:
307+
effective_until = datetime.datetime.fromisoformat(rate_config['effective_until'])
308+
except ValueError:
309+
raise ValueError(f"Invalid effective_until date format for role '{role_name}'. Use ISO format (YYYY-MM-DDTHH:MM:SS)")
310+
311+
# Set the rate
312+
builder.set_alignerr_role_rate(
313+
role_name=alignerr_role,
314+
rate=float(rate_config['rate']),
315+
billing_mode=billing_mode,
316+
effective_since=effective_since,
317+
effective_until=effective_until
318+
)
319+
320+
# Create the project
321+
return builder.create(skip_validation=skip_validation)
322+
323+
324+
class AlignerrWorkspace:
325+
def __init__(self, client: "Client"):
326+
self.client = client
327+
328+
def project_builder(self):
329+
return AlignerrProjectBuilder(self.client)
330+
331+
def project_prototype(self):
332+
return AlignerrProjectFactory(self.client)
333+
334+

libs/labelbox/src/labelbox/alignerr/schema/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)