Skip to content

Commit 0a3f8ef

Browse files
authored
feat: add guardrails to model card report (#445)
Signed-off-by: Ruben Romero Montes <[email protected]>
1 parent 66600f1 commit 0a3f8ef

File tree

10 files changed

+494
-8
lines changed

10 files changed

+494
-8
lines changed

pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
<deploy-plugin.version>3.1.1</deploy-plugin.version>
4949

5050
<!-- Dependencies -->
51-
<exhort-api.version>1.0.16</exhort-api.version>
51+
<exhort-api.version>1.0.17</exhort-api.version>
5252
<sentry.version>7.8.0</sentry.version>
5353
<spdx.version>2.0.2</spdx.version>
5454
<htmlunit.version>4.11.1</htmlunit.version>
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Copyright 2025 Red Hat, Inc. and/or its affiliates
3+
* and other contributors as indicated by the @author tags.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
*
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package com.redhat.exhort.model.modelcards;
20+
21+
import java.util.List;
22+
import java.util.Set;
23+
24+
import jakarta.persistence.Column;
25+
import jakarta.persistence.Entity;
26+
import jakarta.persistence.EnumType;
27+
import jakarta.persistence.Enumerated;
28+
import jakarta.persistence.FetchType;
29+
import jakarta.persistence.GeneratedValue;
30+
import jakarta.persistence.Id;
31+
import jakarta.persistence.JoinColumn;
32+
import jakarta.persistence.JoinTable;
33+
import jakarta.persistence.ManyToMany;
34+
import jakarta.persistence.Table;
35+
36+
@Entity
37+
@Table(name = "guardrail")
38+
public class Guardrail {
39+
40+
@Id @GeneratedValue public Long id;
41+
42+
public String name;
43+
44+
public String description;
45+
46+
@Enumerated(EnumType.STRING)
47+
public GuardrailScope scope;
48+
49+
@ManyToMany(fetch = FetchType.EAGER)
50+
@JoinTable(
51+
name = "guardrail_metrics",
52+
joinColumns = @JoinColumn(name = "guardrail_id"),
53+
inverseJoinColumns = @JoinColumn(name = "task_metric_id"))
54+
public Set<TaskMetric> metrics;
55+
56+
@Column(name = "external_references")
57+
public List<String> references;
58+
59+
@Column(name = "metadata_keys")
60+
public List<String> metadataKeys;
61+
62+
@Column(name = "instructions")
63+
public String instructions;
64+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
/*
2+
* Copyright 2025 Red Hat, Inc. and/or its affiliates
3+
* and other contributors as indicated by the @author tags.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
*
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package com.redhat.exhort.model.modelcards;
20+
21+
public enum GuardrailScope {
22+
INPUT,
23+
OUTPUT,
24+
BOTH;
25+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Copyright 2025 Red Hat, Inc. and/or its affiliates
3+
* and other contributors as indicated by the @author tags.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
*
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package com.redhat.exhort.modelcards;
20+
21+
import java.util.List;
22+
23+
import com.redhat.exhort.model.modelcards.Guardrail;
24+
25+
import io.quarkus.hibernate.orm.panache.PanacheRepository;
26+
27+
import jakarta.enterprise.context.ApplicationScoped;
28+
29+
@ApplicationScoped
30+
public class GuardrailRepository implements PanacheRepository<Guardrail> {
31+
32+
public List<Guardrail> findByTaskMetricIds(List<Long> taskMetricIds) {
33+
return list(
34+
"SELECT DISTINCT g FROM Guardrail g JOIN g.metrics m WHERE m.id IN ?1", taskMetricIds);
35+
}
36+
}

src/main/java/com/redhat/exhort/modelcards/ModelCardService.java

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
package com.redhat.exhort.modelcards;
2020

21+
import java.util.ArrayList;
2122
import java.util.Collections;
2223
import java.util.List;
2324
import java.util.Map;
@@ -33,6 +34,7 @@
3334
import com.redhat.exhort.api.v4.ReportConfig;
3435
import com.redhat.exhort.api.v4.ReportMetric;
3536
import com.redhat.exhort.api.v4.ReportTask;
37+
import com.redhat.exhort.model.modelcards.Guardrail;
3638
import com.redhat.exhort.model.modelcards.ModelCardConfig;
3739
import com.redhat.exhort.model.modelcards.ModelCardReport;
3840
import com.redhat.exhort.model.modelcards.ModelCardTask;
@@ -48,13 +50,30 @@ public class ModelCardService {
4850

4951
@Inject ModelCardRepository repository;
5052

53+
@Inject GuardrailRepository guardrailRepository;
54+
5155
@Transactional
5256
public ModelCardResponse get(UUID id) {
5357
var report = repository.findById(id);
5458
if (report == null) {
5559
return null;
5660
}
57-
return toDto(report);
61+
List<Guardrail> guardrails = new ArrayList<>();
62+
if (report.tasks != null) {
63+
// Extract task metric IDs from the scores map
64+
List<Long> taskMetricIds =
65+
report.tasks.stream()
66+
.flatMap(task -> task.scores.keySet().stream())
67+
.map(metric -> metric.id)
68+
.distinct()
69+
.toList();
70+
71+
if (!taskMetricIds.isEmpty()) {
72+
guardrails = guardrailRepository.findByTaskMetricIds(taskMetricIds);
73+
}
74+
}
75+
76+
return toDto(report, guardrails);
5877
}
5978

6079
@Transactional
@@ -70,13 +89,17 @@ public List<ListModelCardResponse> find(List<ModelCardQueryItem> queries) {
7089
return reports.stream().map(this::toSummaryDto).collect(Collectors.toList());
7190
}
7291

73-
private ModelCardResponse toDto(ModelCardReport entity) {
92+
private ModelCardResponse toDto(ModelCardReport entity, List<Guardrail> guardrails) {
7493
var dto = new ModelCardResponse();
7594
dto.id(entity.id.toString());
7695
dto.name(entity.name);
7796
dto.source(entity.source);
7897
dto.config(toConfigDto(entity.config));
79-
dto.tasks(entity.tasks != null ? entity.tasks.stream().map(this::toTaskDto).toList() : null);
98+
dto.tasks(
99+
entity.tasks != null
100+
? entity.tasks.stream().map(t -> toTaskDto(t, guardrails)).toList()
101+
: null);
102+
dto.guardrails(guardrails.stream().map(this::toGuardrailDto).toList());
80103
return dto;
81104
}
82105

@@ -96,21 +119,23 @@ private ReportConfig toConfigDto(ModelCardConfig entity) {
96119
return dto;
97120
}
98121

99-
private ReportTask toTaskDto(ModelCardTask entity) {
122+
private ReportTask toTaskDto(ModelCardTask entity, List<Guardrail> guardrails) {
100123
if (entity == null) return null;
101124

102125
var dto = new ReportTask();
103126
dto.name(entity.task.name);
104127
dto.description(entity.task.description);
105128
dto.tags(List.copyOf(entity.task.tags));
106-
107129
dto.metrics(
108-
entity.scores.entrySet().stream().map(this::toMetricScoreDto).collect(Collectors.toList()));
130+
entity.scores.entrySet().stream()
131+
.map(s -> toMetricScoreDto(s, guardrails))
132+
.collect(Collectors.toList()));
109133

110134
return dto;
111135
}
112136

113-
private ReportMetric toMetricScoreDto(Map.Entry<TaskMetric, Float> e) {
137+
private ReportMetric toMetricScoreDto(
138+
Map.Entry<TaskMetric, Float> e, List<Guardrail> guardrails) {
114139
var score = new ReportMetric();
115140
score.name(e.getKey().name);
116141
score.score(e.getValue());
@@ -119,6 +144,13 @@ private ReportMetric toMetricScoreDto(Map.Entry<TaskMetric, Float> e) {
119144
e.getKey().thresholds != null
120145
? e.getKey().thresholds.stream().map(this::toThresholdDto).toList()
121146
: null);
147+
var metricGuardrails =
148+
guardrails.stream()
149+
.filter(g -> g.metrics.contains(e.getKey()))
150+
.map(g -> g.id)
151+
.sorted()
152+
.toList();
153+
score.guardrails(metricGuardrails);
122154
score.categories(List.copyOf(e.getKey().categories));
123155
return score;
124156
}
@@ -168,4 +200,16 @@ private String getAssessment(List<Threshold> thresholds, Float score) {
168200
}
169201
return null;
170202
}
203+
204+
private com.redhat.exhort.api.v4.Guardrail toGuardrailDto(Guardrail entity) {
205+
var dto = new com.redhat.exhort.api.v4.Guardrail();
206+
dto.id(entity.id);
207+
dto.name(entity.name);
208+
dto.description(entity.description);
209+
dto.scope(com.redhat.exhort.api.v4.Guardrail.ScopeEnum.valueOf(entity.scope.name()));
210+
dto.externalReferences(entity.references);
211+
dto.metadataKeys(entity.metadataKeys);
212+
dto.instructions(entity.instructions);
213+
return dto;
214+
}
171215
}

src/main/resources/db/README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ This directory contains SQL scripts for setting up the Model Card database schem
77
- `V1__create_model_card_tables.sql` - Creates the database tables for Model Card entities
88
- `V2__insert_base_data.sql` - Inserts initial configuration data for Tasks and Thresholds
99
- `V3__insert_report_data.sql` - Inserts data from the available reports existing at the moment
10+
- `V4__create_guardrail_tables.sql` - Creates the database tables for Guardrail entities
11+
- `V5__inert_guardrail_data.sql` - Inserts Guardrail data for 4 initial recommendations
1012

1113
## Table Structure
1214

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
-- Create tables for Guardrail entity
2+
3+
-- Guardrail table
4+
CREATE TABLE guardrail (
5+
id BIGSERIAL PRIMARY KEY,
6+
name VARCHAR(255),
7+
description TEXT,
8+
scope VARCHAR(255),
9+
external_references VARCHAR(255)[],
10+
metadata_keys VARCHAR(255)[],
11+
instructions TEXT
12+
);
13+
14+
-- Guardrail Metrics join table
15+
CREATE TABLE guardrail_metrics (
16+
guardrail_id BIGINT NOT NULL,
17+
task_metric_id BIGINT NOT NULL,
18+
PRIMARY KEY (guardrail_id, task_metric_id),
19+
FOREIGN KEY (guardrail_id) REFERENCES guardrail(id) ON DELETE CASCADE,
20+
FOREIGN KEY (task_metric_id) REFERENCES task_metric(id) ON DELETE CASCADE
21+
);
22+
23+
-- Create explicit sequence to match Hibernate's naming expectations
24+
CREATE SEQUENCE guardrail_SEQ START WITH 1 INCREMENT BY 50;
25+
26+
-- Update table to use explicit sequence
27+
ALTER TABLE guardrail ALTER COLUMN id SET DEFAULT nextval('guardrail_SEQ');
28+
29+
-- Create indexes for better performance
30+
CREATE INDEX idx_guardrail_name ON guardrail(name);
31+
CREATE INDEX idx_guardrail_scope ON guardrail(scope);
32+
CREATE INDEX idx_guardrail_metrics_guardrail_id ON guardrail_metrics(guardrail_id);
33+
CREATE INDEX idx_guardrail_metrics_task_metric_id ON guardrail_metrics(task_metric_id);

0 commit comments

Comments
 (0)