Skip to content

Commit 5917322

Browse files
committed
PLT-3393: invite users with user_group_roles
1 parent 3a7fb21 commit 5917322

File tree

7 files changed

+239
-21
lines changed

7 files changed

+239
-21
lines changed

libs/labelbox/src/labelbox/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from labelbox.schema.tool_building.step_reasoning_tool import StepReasoningTool
5656
from labelbox.schema.tool_building.prompt_issue_tool import PromptIssueTool
5757
from labelbox.schema.tool_building.relationship_tool import RelationshipTool
58-
from labelbox.schema.role import Role, ProjectRole
58+
from labelbox.schema.role import Role, ProjectRole, UserGroupRole
5959
from labelbox.schema.invite import Invite, InviteLimit
6060
from labelbox.schema.data_row_metadata import (
6161
DataRowMetadataOntology,

libs/labelbox/src/labelbox/client.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -507,16 +507,16 @@ def delete_project_memberships(
507507
self, project_id: str, user_ids: list[str]
508508
) -> dict:
509509
"""Deletes project memberships for one or more users.
510-
510+
511511
Args:
512512
project_id (str): ID of the project
513513
user_ids (list[str]): List of user IDs to remove from the project
514-
514+
515515
Returns:
516516
dict: Result containing:
517517
- success (bool): True if operation succeeded
518518
- errorMessage (str or None): Error message if operation failed
519-
519+
520520
Example:
521521
>>> result = client.delete_project_memberships(
522522
>>> project_id="project123",
@@ -539,12 +539,12 @@ def delete_project_memberships(
539539
errorMessage
540540
}
541541
}"""
542-
542+
543543
params = {
544544
"projectId": project_id,
545545
"userIds": user_ids,
546546
}
547-
547+
548548
result = self.execute(mutation, params)
549549
return result["deleteProjectMemberships"]
550550

libs/labelbox/src/labelbox/orm/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,7 @@ class Entity(metaclass=EntityMeta):
399399
CatalogSlice: Type[labelbox.CatalogSlice]
400400
ModelSlice: Type[labelbox.ModelSlice]
401401
TaskQueue: Type[labelbox.TaskQueue]
402+
UserGroupRole: Type[labelbox.UserGroupRole]
402403

403404
@classmethod
404405
def _attributes_of_type(cls, attr_type):

libs/labelbox/src/labelbox/schema/organization.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import TYPE_CHECKING, Dict, List, Optional, Union
1+
from typing import TYPE_CHECKING, Dict, List, Set, Optional, Union
22

33
from lbox.exceptions import LabelboxError
44

@@ -22,6 +22,7 @@
2222
ProjectRole,
2323
Role,
2424
User,
25+
UserGroupRole,
2526
)
2627

2728

@@ -65,6 +66,7 @@ def invite_user(
6566
email: str,
6667
role: "Role",
6768
project_roles: Optional[List["ProjectRole"]] = None,
69+
user_group_roles: Optional[List["UserGroupRole"]] = None,
6870
) -> "Invite":
6971
"""
7072
Invite a new member to the org. This will send the user an email invite
@@ -88,6 +90,40 @@ def invite_user(
8890
f"Project roles cannot be set for a user with organization level permissions. Found role name `{role.name}`, expected `NONE`"
8991
)
9092

93+
if user_group_roles and role.name != "NONE":
94+
raise ValueError(
95+
f"User Group roles cannot be set for a user with organization level permissions. Found role name `{role.name}`, expected `NONE`"
96+
)
97+
98+
if user_group_roles:
99+
# The backend can 500 if the same groupId appears more than once.
100+
# We dedupe exact duplicates (same groupId+roleId), but reject
101+
# conflicting assignments (same groupId with different roleId).
102+
103+
deduped_user_group_roles: Dict[str, "UserGroupRole"] = {}
104+
conflicting_user_group_ids: Set[str] = set()
105+
106+
for user_group_role in user_group_roles:
107+
user_group_id = user_group_role.user_group.id
108+
role_id = user_group_role.role.uid
109+
110+
existing = deduped_user_group_roles.get(user_group_id)
111+
if existing is None:
112+
deduped_user_group_roles[user_group_id] = user_group_role
113+
else:
114+
if existing.role.uid != role_id:
115+
conflicting_user_group_ids.add(user_group_id)
116+
117+
if conflicting_user_group_ids:
118+
conflicts_str = ", ".join(sorted(conflicting_user_group_ids))
119+
raise ValueError(
120+
"user_group_roles contains conflicting role assignments for "
121+
"the same UserGroup. Each UserGroup may only appear once. "
122+
f"Conflicting user_group.id values: {conflicts_str}"
123+
)
124+
125+
user_group_roles = list(deduped_user_group_roles.values())
126+
91127
data_param = "data"
92128
query_str = """mutation createInvitesPyApi($%s: [CreateInviteInput!]){
93129
createInvites(data: $%s){ invite { id createdAt organizationRoleName inviteeEmail inviter { %s } }}}""" % (
@@ -104,6 +140,19 @@ def invite_user(
104140
for project_role in project_roles or []
105141
]
106142

143+
user_group_ids = [
144+
user_group_role.user_group.id
145+
for user_group_role in user_group_roles or []
146+
]
147+
148+
user_group_with_role_ids = [
149+
{
150+
"groupId": user_group_role.user_group.id,
151+
"roleId": user_group_role.role.uid,
152+
}
153+
for user_group_role in user_group_roles or []
154+
]
155+
107156
res = self.client.execute(
108157
query_str,
109158
{
@@ -114,6 +163,8 @@ def invite_user(
114163
"organizationId": self.uid,
115164
"organizationRoleId": role.uid,
116165
"projects": projects,
166+
"userGroupIds": user_group_ids,
167+
"userGroupWithRoleIds": user_group_with_role_ids,
117168
}
118169
]
119170
},

libs/labelbox/src/labelbox/schema/project.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,9 @@ def get_resource_tags(self) -> List[ResourceTag]:
317317

318318
return [ResourceTag(self.client, tag) for tag in results]
319319

320-
def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedCollection:
320+
def labels(
321+
self, datasets=None, order_by=None, created_by=None
322+
) -> PaginatedCollection:
321323
"""Custom relationship expansion method to support limited filtering.
322324
323325
Args:
@@ -334,7 +336,7 @@ def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedColl
334336
Example:
335337
>>> # Get all labels
336338
>>> all_labels = project.labels()
337-
>>>
339+
>>>
338340
>>> # Get labels by specific user
339341
>>> user_labels = project.labels(created_by=user_id)
340342
>>> # or
@@ -351,16 +353,22 @@ def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedColl
351353

352354
# Build where clause
353355
where_clauses = []
354-
356+
355357
if datasets is not None:
356-
dataset_ids = ", ".join('"%s"' % dataset.uid for dataset in datasets)
357-
where_clauses.append(f"dataRow: {{dataset: {{id_in: [{dataset_ids}]}}}}")
358-
358+
dataset_ids = ", ".join(
359+
'"%s"' % dataset.uid for dataset in datasets
360+
)
361+
where_clauses.append(
362+
f"dataRow: {{dataset: {{id_in: [{dataset_ids}]}}}}"
363+
)
364+
359365
if created_by is not None:
360366
# Handle both User object and user_id string
361-
user_id = created_by.uid if hasattr(created_by, 'uid') else created_by
367+
user_id = (
368+
created_by.uid if hasattr(created_by, "uid") else created_by
369+
)
362370
where_clauses.append(f'createdBy: {{id: "{user_id}"}}')
363-
371+
364372
if where_clauses:
365373
where = " where:{" + ", ".join(where_clauses) + "}"
366374
else:
@@ -396,7 +404,7 @@ def labels(self, datasets=None, order_by=None, created_by=None) -> PaginatedColl
396404

397405
def delete_labels_by_user(self, user_id: str) -> int:
398406
"""Soft deletes all labels created by a specific user in this project.
399-
407+
400408
This performs a soft delete (sets deleted=true in the database).
401409
The labels will no longer appear in queries but remain in the database.
402410
Labels are deleted in chunks of 500 to avoid overwhelming the API.
@@ -413,18 +421,18 @@ def delete_labels_by_user(self, user_id: str) -> int:
413421
>>> print(f"Deleted {deleted_count} labels")
414422
"""
415423
labels_to_delete = list(self.labels(created_by=user_id))
416-
424+
417425
if not labels_to_delete:
418426
return 0
419-
427+
420428
chunk_size = 500
421429
total_deleted = 0
422-
430+
423431
for i in range(0, len(labels_to_delete), chunk_size):
424-
chunk = labels_to_delete[i:i + chunk_size]
432+
chunk = labels_to_delete[i : i + chunk_size]
425433
Entity.Label.bulk_delete(chunk)
426434
total_deleted += len(chunk)
427-
435+
428436
return total_deleted
429437

430438
def export(

libs/labelbox/src/labelbox/schema/role.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
if TYPE_CHECKING:
88
from labelbox import Client, Project
9+
from labelbox.schema.user_group import UserGroup
910

1011
_ROLES: Optional[Dict[str, "Role"]] = None
1112

@@ -45,3 +46,9 @@ class UserRole(Role): ...
4546
class ProjectRole:
4647
project: "Project"
4748
role: Role
49+
50+
51+
@dataclass
52+
class UserGroupRole:
53+
user_group: "UserGroup"
54+
role: Role
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
import pytest
2+
from types import SimpleNamespace
3+
from unittest.mock import MagicMock
4+
5+
from labelbox.schema.role import UserGroupRole
6+
from labelbox.schema.organization import Organization
7+
8+
9+
def test_invite_user_duplicate_user_group_roles_same_role_is_deduped():
10+
client = MagicMock()
11+
client.get_user.return_value = SimpleNamespace(uid="inviter-id")
12+
client.execute.return_value = {
13+
"createInvites": [
14+
{
15+
"invite": {
16+
"id": "invite-id",
17+
"createdAt": "2020-01-01T00:00:00.000Z",
18+
"organizationRoleName": "NONE",
19+
"inviteeEmail": "someone@example.com",
20+
"inviter": {"id": "inviter-id"},
21+
}
22+
}
23+
]
24+
}
25+
26+
organization = Organization(
27+
client,
28+
{
29+
"id": "org-id",
30+
"name": "Test Org",
31+
"createdAt": "2020-01-01T00:00:00.000Z",
32+
"updatedAt": "2020-01-01T00:00:00.000Z",
33+
},
34+
)
35+
36+
org_role_none = SimpleNamespace(uid="org-role-none-id", name="NONE")
37+
reviewer_role = SimpleNamespace(uid="reviewer-role-id", name="REVIEWER")
38+
user_group = SimpleNamespace(id="user-group-id")
39+
40+
user_group_roles = [
41+
UserGroupRole(user_group=user_group, role=reviewer_role),
42+
UserGroupRole(user_group=user_group, role=reviewer_role),
43+
]
44+
45+
organization.invite_user(
46+
email="someone@example.com",
47+
role=org_role_none,
48+
user_group_roles=user_group_roles,
49+
)
50+
51+
# ensure we only send one entry per group
52+
args, kwargs = client.execute.call_args
53+
assert kwargs == {}
54+
payload = args[1]["data"][0]
55+
assert payload["userGroupIds"] == ["user-group-id"]
56+
assert payload["userGroupWithRoleIds"] == [
57+
{"groupId": "user-group-id", "roleId": "reviewer-role-id"}
58+
]
59+
60+
61+
def test_invite_user_duplicate_user_group_roles_conflicting_roles_raises_value_error():
62+
client = MagicMock()
63+
client.get_user.return_value = SimpleNamespace(uid="inviter-id")
64+
65+
organization = Organization(
66+
client,
67+
{
68+
"id": "org-id",
69+
"name": "Test Org",
70+
"createdAt": "2020-01-01T00:00:00.000Z",
71+
"updatedAt": "2020-01-01T00:00:00.000Z",
72+
},
73+
)
74+
75+
org_role_none = SimpleNamespace(uid="org-role-none-id", name="NONE")
76+
reviewer_role = SimpleNamespace(uid="reviewer-role-id", name="REVIEWER")
77+
team_manager_role = SimpleNamespace(
78+
uid="team-manager-role-id", name="TEAM_MANAGER"
79+
)
80+
user_group = SimpleNamespace(id="user-group-id")
81+
82+
user_group_roles = [
83+
UserGroupRole(user_group=user_group, role=reviewer_role),
84+
UserGroupRole(user_group=user_group, role=team_manager_role),
85+
]
86+
87+
with pytest.raises(ValueError, match="conflicting role assignments"):
88+
organization.invite_user(
89+
email="someone@example.com",
90+
role=org_role_none,
91+
user_group_roles=user_group_roles,
92+
)
93+
94+
client.execute.assert_not_called()
95+
96+
97+
def test_invite_user_user_group_roles_payload_contains_all_groups():
98+
client = MagicMock()
99+
client.get_user.return_value = SimpleNamespace(uid="inviter-id")
100+
client.execute.return_value = {
101+
"createInvites": [
102+
{
103+
"invite": {
104+
"id": "invite-id",
105+
"createdAt": "2020-01-01T00:00:00.000Z",
106+
"organizationRoleName": "NONE",
107+
"inviteeEmail": "someone@example.com",
108+
"inviter": {"id": "inviter-id"},
109+
}
110+
}
111+
]
112+
}
113+
114+
organization = Organization(
115+
client,
116+
{
117+
"id": "org-id",
118+
"name": "Test Org",
119+
"createdAt": "2020-01-01T00:00:00.000Z",
120+
"updatedAt": "2020-01-01T00:00:00.000Z",
121+
},
122+
)
123+
124+
org_role_none = SimpleNamespace(uid="org-role-none-id", name="NONE")
125+
reviewer_role = SimpleNamespace(uid="reviewer-role-id", name="REVIEWER")
126+
team_manager_role = SimpleNamespace(
127+
uid="team-manager-role-id", name="TEAM_MANAGER"
128+
)
129+
130+
ug1 = SimpleNamespace(id="user-group-1")
131+
ug2 = SimpleNamespace(id="user-group-2")
132+
133+
user_group_roles = [
134+
UserGroupRole(user_group=ug1, role=reviewer_role),
135+
UserGroupRole(user_group=ug2, role=team_manager_role),
136+
]
137+
138+
organization.invite_user(
139+
email="someone@example.com",
140+
role=org_role_none,
141+
user_group_roles=user_group_roles,
142+
)
143+
144+
args, kwargs = client.execute.call_args
145+
assert kwargs == {}
146+
payload = args[1]["data"][0]
147+
assert payload["userGroupIds"] == ["user-group-1", "user-group-2"]
148+
assert payload["userGroupWithRoleIds"] == [
149+
{"groupId": "user-group-1", "roleId": "reviewer-role-id"},
150+
{"groupId": "user-group-2", "roleId": "team-manager-role-id"},
151+
]

0 commit comments

Comments
 (0)