Luca Latini commited on
Commit
d2ece60
·
1 Parent(s): 81f7e58

Refactor EnrollmentRule handling of GROUPS and COURSES

Browse files
Files changed (1) hide show
  1. app.py +20 -57
app.py CHANGED
@@ -1,8 +1,5 @@
1
  import gradio as gr
2
- import pydantic
3
- import random
4
-
5
- from pydantic import Field, field_validator, BaseModel
6
  from enum import Enum
7
  from typing import List
8
 
@@ -12,8 +9,6 @@ class Group(str, Enum):
12
  marketing = "marketing"
13
  leadership = "leadership"
14
 
15
- GROUPS = ["onboarding", "sales enablement", "marketing", "leadership"]
16
-
17
  class Course(str, Enum):
18
  company_culture_and_values = "company culture and values"
19
  product_knowledge_and_features = "product knowledge and features"
@@ -32,89 +27,57 @@ class Course(str, Enum):
32
  change_management_and_organizational_development = "change management and organizational development"
33
  emotional_intelligence_and_communication_skills = "emotional intelligence and communication skills"
34
 
35
- COURSES = ["company culture and values",
36
- "product knowledge and features",
37
- "sales process and methodology",
38
- "customer relationship management system training",
39
- "sales enablement strategy and planning",
40
- "content creation and curation",
41
- "sales coaching and mentoring",
42
- "sales metrics and analytics",
43
- "digital marketing fundamentals",
44
- "content marketing strategy",
45
- "social media marketing",
46
- "search engine optimization",
47
- "leadership development and coaching",
48
- "strategic planning and decision making",
49
- "change management and organizational development",
50
- "emotional intelligence and communication skills"
51
- ]
52
-
53
  class EnrollmentRule(BaseModel):
54
  rule_code: str = Field(description="unique identifier code for the rule", default="")
55
  rule_name: str = Field(description="name of the rule", default="")
56
  group: Group = Field(description="group to apply the rule to", default=Group.onboarding)
57
- courses: List[Course] = Field(description="list of courses that the members of the group will follow", default=[course.value for course in Course])
 
 
 
58
 
59
  @field_validator("rule_code")
60
  def rule_code_max_length(cls, rule_code):
61
- # Check if empty
62
  if not rule_code.strip():
63
  raise ValueError("The rule code cannot be empty.")
64
- # Check length
65
  if len(rule_code) > 50:
66
  raise ValueError("The rule code must contain less than 50 characters.")
67
  return rule_code
 
68
  @field_validator("rule_name")
69
  def rule_name_max_length(cls, rule_name):
70
- # Check if empty
71
  if not rule_name.strip():
72
  raise ValueError("The rule name cannot be empty.")
73
- # Check length
74
  if len(rule_name) > 255:
75
  raise ValueError("The rule name must contain less than 255 characters.")
76
  return rule_name
77
- @field_validator("group")
78
- def group_is_valid(cls, group):
79
- # Check if empty
80
- if not group or not group.strip():
81
- raise ValueError("Group cannot be empty.")
82
- # Validate group value
83
- if group.lower().strip() not in GROUPS:
84
- raise ValueError(f"group '{group}' is not a valid value for field 'group'. 'group' must be one of the following: '{', '.join(GROUPS)}'")
85
- return group
86
  @field_validator("courses")
87
- def course_list_is_valid(cls, courses):
88
- # Ensure at least one course is selected
89
  if not courses:
90
  raise ValueError("At least one course must be provided.")
91
-
92
- # Validate all listed courses
93
- invalid_courses = []
94
- for course in courses:
95
- if course.lower().strip() not in COURSES:
96
- invalid_courses.append(course)
97
- if invalid_courses:
98
- raise ValueError(f"course(s) '{', '.join(invalid_courses)}' is not a valid value for field 'courses'. 'course' must be one or more of the following: '{', '.join(COURSES)}'")
99
  return courses
100
-
101
  def create_enrollment_rule(enrollment_rule: EnrollmentRule) -> dict:
102
- """Create an enrollment rule, based on the rule parameters provided by the user"""
103
- payload = enrollment_rule.model_dump(mode="json")
104
- return payload
105
 
106
- # A small Gradio-friendly wrapper that converts user inputs into an EnrollmentRule
107
  def gr_create_enrollment_rule(rule_code, rule_name, group, courses):
108
- # Build the Pydantic model from the raw inputs
 
 
 
109
  enrollment_rule = EnrollmentRule(
110
  rule_code=rule_code,
111
  rule_name=rule_name,
112
- group=group,
113
- courses=courses
114
  )
115
  return create_enrollment_rule(enrollment_rule)
116
 
117
- # Define the Gradio interface
 
 
 
118
  demo = gr.Interface(
119
  fn=gr_create_enrollment_rule,
120
  inputs=[
 
1
  import gradio as gr
2
+ from pydantic import BaseModel, Field, field_validator
 
 
 
3
  from enum import Enum
4
  from typing import List
5
 
 
9
  marketing = "marketing"
10
  leadership = "leadership"
11
 
 
 
12
  class Course(str, Enum):
13
  company_culture_and_values = "company culture and values"
14
  product_knowledge_and_features = "product knowledge and features"
 
27
  change_management_and_organizational_development = "change management and organizational development"
28
  emotional_intelligence_and_communication_skills = "emotional intelligence and communication skills"
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  class EnrollmentRule(BaseModel):
31
  rule_code: str = Field(description="unique identifier code for the rule", default="")
32
  rule_name: str = Field(description="name of the rule", default="")
33
  group: Group = Field(description="group to apply the rule to", default=Group.onboarding)
34
+ courses: List[Course] = Field(
35
+ description="list of courses that the members of the group will follow",
36
+ default_factory=lambda: list(Course)
37
+ )
38
 
39
  @field_validator("rule_code")
40
  def rule_code_max_length(cls, rule_code):
 
41
  if not rule_code.strip():
42
  raise ValueError("The rule code cannot be empty.")
 
43
  if len(rule_code) > 50:
44
  raise ValueError("The rule code must contain less than 50 characters.")
45
  return rule_code
46
+
47
  @field_validator("rule_name")
48
  def rule_name_max_length(cls, rule_name):
 
49
  if not rule_name.strip():
50
  raise ValueError("The rule name cannot be empty.")
 
51
  if len(rule_name) > 255:
52
  raise ValueError("The rule name must contain less than 255 characters.")
53
  return rule_name
54
+
 
 
 
 
 
 
 
 
55
  @field_validator("courses")
56
+ def course_list_not_empty(cls, courses):
 
57
  if not courses:
58
  raise ValueError("At least one course must be provided.")
 
 
 
 
 
 
 
 
59
  return courses
60
+
61
  def create_enrollment_rule(enrollment_rule: EnrollmentRule) -> dict:
62
+ return enrollment_rule.model_dump(mode="json")
 
 
63
 
 
64
  def gr_create_enrollment_rule(rule_code, rule_name, group, courses):
65
+ # Convert group and courses from strings to enum values
66
+ # Pydantic will handle this automatically, but we need to ensure the raw inputs are enums.
67
+ # If the input is directly from the Gradio UI, they will be strings. Pydantic tries to coerce
68
+ # them into enum values. If it fails, it will raise a validation error.
69
  enrollment_rule = EnrollmentRule(
70
  rule_code=rule_code,
71
  rule_name=rule_name,
72
+ group=Group(group), # Will raise ValueError if invalid
73
+ courses=[Course(c) for c in courses] # Will raise ValueError if invalid
74
  )
75
  return create_enrollment_rule(enrollment_rule)
76
 
77
+ # Derive the choices from the Enums directly:
78
+ GROUPS = [g.value for g in Group]
79
+ COURSES = [c.value for c in Course]
80
+
81
  demo = gr.Interface(
82
  fn=gr_create_enrollment_rule,
83
  inputs=[