Skip to content

Commit 09038ad

Browse files
authored
Merge pull request qodo-ai#530 from Codium-ai/tr/labels
Enhancement: Implement label case conversion and update label descriptions in settings files
2 parents bc2afba + b28829a commit 09038ad

File tree

5 files changed

+37
-6
lines changed

5 files changed

+37
-6
lines changed

pr_agent/algo/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,15 @@ def set_custom_labels(variables, git_provider=None):
379379

380380
# Set custom labels
381381
variables["custom_labels_class"] = "class Label(str, Enum):"
382+
counter = 0
383+
labels_minimal_to_labels_dict = {}
382384
for k, v in labels.items():
383-
description = v['description'].strip('\n').replace('\n', '\\n')
384-
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
385+
description = "'" + v['description'].strip('\n').replace('\n', '\\n') + "'"
386+
# variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = '{k}' # {description}"
387+
variables["custom_labels_class"] += f"\n {k.lower().replace(' ', '_')} = {description}"
388+
labels_minimal_to_labels_dict[k.lower().replace(' ', '_')] = k
389+
counter += 1
390+
variables["labels_minimal_to_labels_dict"] = labels_minimal_to_labels_dict
385391

386392
def get_user_labels(current_labels: List[str] = None):
387393
"""

pr_agent/settings/pr_custom_labels.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class Label(str, Enum):
3030
{%- endif %}
3131
3232
class Labels(BaseModel):
33-
labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.")
33+
labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.")
3434
======
3535
3636

pr_agent/settings/pr_description_prompts.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class FileWalkthrough(BaseModel):
3737
{%- endif %}
3838
3939
{%- if enable_semantic_files_types %}
40+
4041
Class FileDescription(BaseModel):
4142
filename: str = Field(description="the relevant file full path")
4243
changes_summary: str = Field(description="minimal and concise summary of the changes in the relevant file")
@@ -48,7 +49,7 @@ Class PRDescription(BaseModel):
4849
type: List[PRType] = Field(description="one or more types that describe the PR type. Return the label value, not the name.")
4950
description: str = Field(description="an informative and concise description of the PR. {%- if use_bullet_points %} Use bullet points.{% endif %}")
5051
{%- if enable_custom_labels %}
51-
labels: List[Label] = Field(min_items=0, description="custom labels that describe the PR. Return the label value, not the name.")
52+
labels: List[Label] = Field(min_items=0, description="choose the relevant custom labels that describe the PR content, and return their keys. Use the value field of the Label object to better understand the label meaning.")
5253
{%- endif %}
5354
{%- if enable_file_walkthrough %}
5455
main_files_walkthrough: List[FileWalkthrough] = Field(max_items=10)
@@ -69,8 +70,10 @@ type:
6970
- ...
7071
{%- if enable_custom_labels %}
7172
labels:
72-
- ...
73-
- ...
73+
- |
74+
...
75+
- |
76+
...
7477
{%- endif %}
7578
description: |-
7679
...

pr_agent/tools/pr_description.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ async def _get_prediction(self, model: str) -> str:
163163

164164
environment = Environment(undefined=StrictUndefined)
165165
set_custom_labels(variables, self.git_provider)
166+
self.variables = variables
166167
system_prompt = environment.from_string(get_settings().pr_description_prompt.system).render(variables)
167168
user_prompt = environment.from_string(get_settings().pr_description_prompt.user).render(variables)
168169

@@ -204,6 +205,16 @@ def _prepare_labels(self) -> List[str]:
204205
pr_types = self.data['type']
205206
elif type(self.data['type']) == str:
206207
pr_types = self.data['type'].split(',')
208+
209+
# convert lowercase labels to original case
210+
try:
211+
if "labels_minimal_to_labels_dict" in self.variables:
212+
d: dict = self.variables["labels_minimal_to_labels_dict"]
213+
for i, label_i in enumerate(pr_types):
214+
if label_i in d:
215+
pr_types[i] = d[label_i]
216+
except Exception as e:
217+
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
207218
return pr_types
208219

209220
def _prepare_pr_answer_with_markers(self) -> Tuple[str, str]:

pr_agent/tools/pr_generate_labels.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ async def _get_prediction(self, model: str) -> str:
136136

137137
environment = Environment(undefined=StrictUndefined)
138138
set_custom_labels(variables, self.git_provider)
139+
self.variables = variables
139140
system_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.system).render(variables)
140141
user_prompt = environment.from_string(get_settings().pr_custom_labels_prompt.user).render(variables)
141142

@@ -171,4 +172,14 @@ def _prepare_labels(self) -> List[str]:
171172
elif type(self.data['labels']) == str:
172173
pr_types = self.data['labels'].split(',')
173174

175+
# convert lowercase labels to original case
176+
try:
177+
if "labels_minimal_to_labels_dict" in self.variables:
178+
d: dict = self.variables["labels_minimal_to_labels_dict"]
179+
for i, label_i in enumerate(pr_types):
180+
if label_i in d:
181+
pr_types[i] = d[label_i]
182+
except Exception as e:
183+
get_logger().error(f"Error converting labels to original case {self.pr_id}: {e}")
184+
174185
return pr_types

0 commit comments

Comments
 (0)