Martín Santillán Cooper committed on
Commit
f97dae7
·
1 Parent(s): f492568

prepare for openshift deployment

Browse files
.dockerignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+ .*
3
+ *.yml
4
+ *.yaml
5
+ *.sh
6
+ *.md
7
+ __pycache__/
8
+ flagged/
.env.example CHANGED
@@ -1,3 +1,5 @@
1
  MODEL_PATH='../dmf_models/granite-guardian-8b-pipecleaner-r241024a'
2
  USE_CONDA='true'
3
- MOCK_MODEL_CALL='false'
 
 
 
1
  MODEL_PATH='../dmf_models/granite-guardian-8b-pipecleaner-r241024a'
2
  USE_CONDA='true'
3
+ INFERENCE_ENGINE='' # one of [WATSONX, MOCK, VLLM]
4
+ WATSONX_API_KEY=""
5
+ WATSONX_PROJECT_ID=""
.gitignore CHANGED
@@ -2,5 +2,6 @@
2
  .env
3
  parse.py
4
  unparsed_catalog.json
5
- __pycache__
6
- logs
 
 
2
  .env
3
  parse.py
4
  unparsed_catalog.json
5
+ __pycache__/
6
+ logs.txt
7
+ secrets.yaml
Dockerfile ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ FROM python:3.12-slim
2
+ WORKDIR /usr/src/app
3
+ COPY . .
4
+ RUN pip --disable-pip-version-check --no-cache-dir --no-input install -r requirements.txt
5
+ ENV GRADIO_SERVER_NAME="0.0.0.0"
6
+ EXPOSE 7860
7
+ CMD ["python", "src/app.py"]
cicd/build.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ docker build --platform=linux/amd64 . -t granite-guardian
2
+ docker tag granite-guardian us.icr.io/research3/granite-guardian
cicd/deploy.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ibmcloud cr login
2
+ oc delete -f deployment.yaml
3
+ oc apply -f deployment.yaml
cicd/push_image.sh ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ibmcloud target -g aipt-experiments
2
+ docker push us.icr.io/research3/granite-guardian
cicd/run.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ./build.sh
2
+ ./push_image.sh
3
+ ./deploy.sh
deployment.yaml ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ apiVersion: apps/v1
2
+ kind: Deployment
3
+ metadata:
4
+ name: granite-guardian-pod
5
+ labels:
6
+ app: granite-guardian
7
+ spec:
8
+ selector:
9
+ matchLabels:
10
+ run: granite-guardian
11
+ replicas: 1
12
+ template:
13
+ metadata:
14
+ labels:
15
+ run: granite-guardian
16
+ spec:
17
+ containers:
18
+ - name: granite-guardian
19
+ image: us.icr.io/research3/granite-guardian
20
+ resources:
21
+ limits:
22
+ cpu: 1
23
+ memory: 2Gi
24
+ requests:
25
+ cpu: 1
26
+ memory: 2Gi
27
+ ports:
28
+ - containerPort: 7860
29
+ env:
30
+ - name: WATSONX_API_KEY
31
+ valueFrom:
32
+ secretKeyRef:
33
+ name: granite-guardian-secrets
34
+ key: WATSONX_API_KEY
35
+ - name: WATSONX_PROJECT_ID
36
+ valueFrom:
37
+ secretKeyRef:
38
+ name: granite-guardian-secrets
39
+ key: WATSONX_PROJECT_ID
40
+ - name: INFERENCE_ENGINE
41
+ valueFrom:
42
+ secretKeyRef:
43
+ name: granite-guardian-secrets
44
+ key: INFERENCE_ENGINE
45
+ imagePullSecrets:
46
+ - name: all-icr-io
47
+ ---
48
+ apiVersion: v1
49
+ kind: Service
50
+ metadata:
51
+ name: granite-guardian-service
52
+ spec:
53
+ type: NodePort
54
+ sessionAffinity: "ClientIP"
55
+ selector:
56
+ run: granite-guardian
57
+ ports:
58
+ - port: 80
59
+ targetPort: 7860
60
+ protocol: TCP
61
+ ---
62
+ apiVersion: networking.k8s.io/v1
63
+ kind: Ingress
64
+ metadata:
65
+ annotations:
66
+ ingress.kubernetes.io/allow-http: 'false'
67
+ ingress.kubernetes.io/ssl-redirect: 'true'
68
+ kubernetes.io/ingress.class: f5
69
+ virtual-server.f5.com/balance: round-robin
70
+ virtual-server.f5.com/ip: 9.12.246.36
71
+ virtual-server.f5.com/partition: RIS3-INT-OCP-DAL12
72
+ virtual-server.f5.com/clientssl: '[ { "bigIpProfile": "/Common/BlueMix" } ]'
73
+ name: granite-guardian-ingress
74
+ namespace: granite-guardian
75
+ spec:
76
+ rules:
77
+ - host: granite-guardian.bx.cloud9.ibm.com
78
+ http:
79
+ paths:
80
+ - backend:
81
+ service:
82
+ name: granite-guardian-service
83
+ port:
84
+ number: 80
85
+ path: /
86
+ pathType: ImplementationSpecific
requirements.txt CHANGED
@@ -1,4 +1,7 @@
1
- gradio
2
  python-dotenv
3
  tqdm
4
  jinja2
 
 
 
 
1
+ gradio>=4,<5
2
  python-dotenv
3
  tqdm
4
  jinja2
5
+ ibm_watsonx_ai
6
+ transformers
7
+ gradio_modal
run_cicd.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ./cicd/build.sh
2
+ ./cicd/push_image.sh
3
+ ./cicd/deploy.sh
src/app.py CHANGED
@@ -112,6 +112,7 @@ def on_show_prompt_click(criteria, context, user_message, assistant_message, sta
112
 
113
  messages = get_messages(test_case=test_case, sub_catalog_name=state['selected_sub_catalog'])
114
  prompt = get_prompt(messages, criteria_name)
 
115
  prompt = prompt.replace('<', '&lt;').replace('>', '&gt;').replace('\\n', '<br>')
116
  return gr.Markdown(prompt)
117
 
@@ -155,7 +156,7 @@ with gr.Blocks(
155
  ),
156
  head=head_style,
157
  fill_width=False,
158
- css=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'styles.css')
159
  ) as demo:
160
 
161
  state = gr.State(value={
 
112
 
113
  messages = get_messages(test_case=test_case, sub_catalog_name=state['selected_sub_catalog'])
114
  prompt = get_prompt(messages, criteria_name)
115
+ print(prompt)
116
  prompt = prompt.replace('<', '&lt;').replace('>', '&gt;').replace('\\n', '<br>')
117
  return gr.Markdown(prompt)
118
 
 
156
  ),
157
  head=head_style,
158
  fill_width=False,
159
+ css=os.path.join(os.path.dirname(os.path.abspath(__file__)), './styles.css')
160
  ) as demo:
161
 
162
  state = gr.State(value={
src/logger.py CHANGED
@@ -7,6 +7,6 @@ stream_handler = logging.StreamHandler()
7
  stream_handler.setLevel(logging.DEBUG)
8
  logger.addHandler(stream_handler)
9
 
10
- file_handler = logging.FileHandler('logs')
11
  file_handler.setFormatter(logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
12
  logger.addHandler(file_handler)
 
7
  stream_handler.setLevel(logging.DEBUG)
8
  logger.addHandler(stream_handler)
9
 
10
+ file_handler = logging.FileHandler('logs.txt')
11
  file_handler.setFormatter(logging.Formatter("%(asctime)s - %(filename)s:%(lineno)d - %(message)s", datefmt="%Y-%m-%d %H:%M:%S"))
12
  logger.addHandler(file_handler)
src/model.py CHANGED
@@ -2,13 +2,20 @@ import os
2
  from time import time, sleep
3
  from logger import logger
4
  import math
 
 
 
 
 
5
 
6
  safe_token = "No"
7
- unsafe_token = "Yes"
8
  nlogprobs = 5
9
 
10
- mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
11
- if not mock_model_call:
 
 
12
  import torch
13
  from vllm import LLM, SamplingParams
14
  from transformers import AutoTokenizer
@@ -18,6 +25,21 @@ if not mock_model_call:
18
  sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
19
  model = LLM(model=model_path, tensor_parallel_size=1)
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  def parse_output(output):
22
  label, prob = None, None
23
 
@@ -28,8 +50,8 @@ def parse_output(output):
28
  prob_of_risk = prob[1]
29
 
30
  res = next(iter(output.outputs)).text.strip()
31
- if unsafe_token.lower() == res.lower():
32
- label = unsafe_token
33
  elif safe_token.lower() == res.lower():
34
  label = safe_token
35
  else:
@@ -37,6 +59,11 @@ def parse_output(output):
37
 
38
  return label, prob_of_risk.item()
39
 
 
 
 
 
 
40
  def get_probablities(logprobs):
41
  safe_token_prob = 1e-50
42
  unsafe_token_prob = 1e-50
@@ -45,7 +72,7 @@ def get_probablities(logprobs):
45
  decoded_token = token_prob.decoded_token
46
  if decoded_token.strip().lower() == safe_token.lower():
47
  safe_token_prob += math.exp(token_prob.logprob)
48
- if decoded_token.strip().lower() == unsafe_token.lower():
49
  unsafe_token_prob += math.exp(token_prob.logprob)
50
 
51
  probabilities = torch.softmax(
@@ -54,6 +81,20 @@ def get_probablities(logprobs):
54
 
55
  return probabilities
56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  def get_prompt(messages, criteria_name):
58
  guardian_config = {"risk_name": criteria_name if criteria_name != 'general_harm' else 'harm'}
59
  return tokenizer.apply_chat_template(
@@ -62,26 +103,65 @@ def get_prompt(messages, criteria_name):
62
  tokenize=False,
63
  add_generation_prompt=True)
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  def generate_text(messages, criteria_name):
67
- logger.debug(f'Messages are: \n{messages}')
68
-
69
- mock_model_call = os.getenv('MOCK_MODEL_CALL') == 'true'
70
- if mock_model_call:
71
- logger.debug('Returning mocked model result.')
72
- sleep(1)
73
- return {'assessment': 'Yes', 'certainty': 0.97}
74
 
75
  start = time()
 
76
  chat = get_prompt(messages, criteria_name)
77
  logger.debug(f'Prompt is \n{chat}')
 
 
 
 
 
78
 
79
- with torch.no_grad():
80
- output = model.generate(chat, sampling_params, use_tqdm=False)
 
81
 
82
- # predicted_label = output[0].outputs[0].text.strip()
 
 
83
 
84
- label, prob_of_risk = parse_output(output[0])
 
 
85
 
86
  logger.debug(f'Model generated label: \n{label}')
87
  logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')
 
2
  from time import time, sleep
3
  from logger import logger
4
  import math
5
+ import os
6
+ from ibm_watsonx_ai.client import APIClient
7
+ from ibm_watsonx_ai.foundation_models import ModelInference
8
+ from transformers import AutoTokenizer
9
+ import math
10
 
11
  safe_token = "No"
12
+ risky_token = "Yes"
13
  nlogprobs = 5
14
 
15
+ inference_engine = os.getenv('INFERENCE_ENGINE')
16
+ logger.debug(f"Inference engine is: '{inference_engine}'")
17
+
18
+ if inference_engine == 'VLLM':
19
  import torch
20
  from vllm import LLM, SamplingParams
21
  from transformers import AutoTokenizer
 
25
  sampling_params = SamplingParams(temperature=0.0, logprobs=nlogprobs)
26
  model = LLM(model=model_path, tensor_parallel_size=1)
27
 
28
+ elif inference_engine == "WATSONX":
29
+ client = APIClient(credentials={
30
+ 'api_key': os.getenv('WATSONX_API_KEY'),
31
+ 'url': 'https://us-south.ml.cloud.ibm.com'})
32
+
33
+ client.set.default_project(os.getenv('WATSONX_PROJECT_ID'))
34
+ hf_model_path = "ibm-granite/granite-guardian-3.0-8b"
35
+ tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
36
+
37
+ model_id = "ibm/granite-guardian-3-8b" # 8B Model: "ibm/granite-guardian-3-8b"
38
+ model = ModelInference(
39
+ model_id=model_id,
40
+ api_client=client
41
+ )
42
+
43
  def parse_output(output):
44
  label, prob = None, None
45
 
 
50
  prob_of_risk = prob[1]
51
 
52
  res = next(iter(output.outputs)).text.strip()
53
+ if risky_token.lower() == res.lower():
54
+ label = risky_token
55
  elif safe_token.lower() == res.lower():
56
  label = safe_token
57
  else:
 
59
 
60
  return label, prob_of_risk.item()
61
 
62
def softmax(values):
    """Return the softmax of ``values`` as a list of floats.

    Args:
        values: Iterable of real numbers (e.g. log-probabilities).

    Returns:
        A list of the same length whose entries are ``exp(v) / sum(exp(values))``.
        An empty input yields an empty list.
    """
    values = list(values)
    if not values:
        return []

    # Shifting every input by the maximum leaves the result unchanged
    # mathematically, but makes the largest exponent exp(0) == 1. That keeps
    # math.exp from overflowing on large inputs and guarantees a non-zero
    # denominator when all inputs are very negative.
    peak = max(values)
    exp_values = [math.exp(v - peak) for v in values]
    total = sum(exp_values)
    return [v / total for v in exp_values]
66
+
67
  def get_probablities(logprobs):
68
  safe_token_prob = 1e-50
69
  unsafe_token_prob = 1e-50
 
72
  decoded_token = token_prob.decoded_token
73
  if decoded_token.strip().lower() == safe_token.lower():
74
  safe_token_prob += math.exp(token_prob.logprob)
75
+ if decoded_token.strip().lower() == risky_token.lower():
76
  unsafe_token_prob += math.exp(token_prob.logprob)
77
 
78
  probabilities = torch.softmax(
 
81
 
82
  return probabilities
83
 
84
def get_probablities_watsonx(top_tokens_list):
    """Turn per-position top-token candidates into [safe, risky] probabilities.

    Args:
        top_tokens_list: Iterable of candidate lists, one per generated
            position. Each candidate is a mapping with a ``'text'`` entry and
            a ``'logprob'`` entry.

    Returns:
        A two-element list ``[p_safe, p_risky]`` produced by ``softmax`` over
        the log of the accumulated probability mass of each label.
    """
    # Start from a tiny positive mass so math.log below never sees zero when a
    # label does not appear among the candidates.
    safe_mass = 1e-50
    risky_mass = 1e-50

    safe_key = safe_token.lower()
    risky_key = risky_token.lower()

    # Flatten positions and candidates; every candidate is inspected once.
    candidates = (cand for top_tokens in top_tokens_list for cand in top_tokens)
    for cand in candidates:
        text = cand['text'].strip().lower()
        if text == safe_key:
            safe_mass += math.exp(cand['logprob'])
        if text == risky_key:
            risky_mass += math.exp(cand['logprob'])

    return softmax([math.log(safe_mass), math.log(risky_mass)])
97
+
98
  def get_prompt(messages, criteria_name):
99
  guardian_config = {"risk_name": criteria_name if criteria_name != 'general_harm' else 'harm'}
100
  return tokenizer.apply_chat_template(
 
103
  tokenize=False,
104
  add_generation_prompt=True)
105
 
106
def generate_tokens(prompt):
    """Run greedy generation for ``prompt`` and return per-token details.

    Calls ``generate`` on the module-level ``model`` (presumably the watsonx
    ``ModelInference`` created when INFERENCE_ENGINE is 'WATSONX' -- confirm)
    and asks it to return token log-probabilities and the top candidate tokens
    for every generated position.

    Args:
        prompt: The fully rendered prompt string.

    Returns:
        The ``'generated_tokens'`` entry of the first result of the first
        response item.
    """
    result = model.generate(
        prompt=[prompt],
        params={
            'decoding_method': 'greedy',
            'max_new_tokens': 20,
            "temperature": 0,
            "return_options": {
                "token_logprobs": True,
                "generated_tokens": True,
                "input_text": True,
                # Use the shared module setting instead of a literal 5 so this
                # path stays in step with the vLLM sampling parameters and with
                # the ``nlogprobs > 0`` gate in parse_output_watsonx.
                # NOTE(review): the service may cap this value -- confirm
                # before raising nlogprobs.
                "top_n_tokens": nlogprobs,
            },
        },
    )
    # A single prompt was submitted, so only the first response/result matters.
    return result[0]['results'][0]['generated_tokens']
121
+
122
def parse_output_watsonx(generated_tokens_list):
    """Derive the label and risk probability from watsonx generated tokens.

    Args:
        generated_tokens_list: List of per-position token mappings, each with
            a ``'text'`` entry and (when top tokens were requested) a
            ``'top_tokens'`` entry.

    Returns:
        A ``(label, prob_of_risk)`` tuple. ``label`` is ``risky_token``,
        ``safe_token`` or ``"Failed"`` depending on the first generated token.
        ``prob_of_risk`` is the second entry returned by
        ``get_probablities_watsonx``, or ``None`` when ``nlogprobs`` is not
        positive.
    """
    label, prob_of_risk = None, None

    if nlogprobs > 0:
        top_tokens_list = [generated_tokens['top_tokens'] for generated_tokens in generated_tokens_list]
        prob = get_probablities_watsonx(top_tokens_list)
        prob_of_risk = prob[1]

    # An empty generation previously raised StopIteration here; treat it as an
    # unparseable answer so it is reported through the "Failed" label instead.
    first_token = next(iter(generated_tokens_list), None)
    res = first_token['text'].strip() if first_token is not None else ''

    if risky_token.lower() == res.lower():
        label = risky_token
    elif safe_token.lower() == res.lower():
        label = safe_token
    else:
        label = "Failed"

    return label, prob_of_risk
140
 
141
  def generate_text(messages, criteria_name):
142
+ logger.debug(f'Messages used to create the prompt are: \n{messages}')
 
 
 
 
 
 
143
 
144
  start = time()
145
+
146
  chat = get_prompt(messages, criteria_name)
147
  logger.debug(f'Prompt is \n{chat}')
148
+
149
+ if inference_engine=="MOCK":
150
+ logger.debug('Returning mocked model result.')
151
+ sleep(1)
152
+ label, prob_of_risk = 'Yes', 0.97
153
 
154
+ elif inference_engine=="WATSONX":
155
+ generated_tokens = generate_tokens(chat)
156
+ label, prob_of_risk = parse_output_watsonx(generated_tokens)
157
 
158
+ elif inference_engine=="VLLM":
159
+ with torch.no_grad():
160
+ output = model.generate(chat, sampling_params, use_tqdm=False)
161
 
162
+ label, prob_of_risk = parse_output(output[0])
163
+ else:
164
+ raise Exception("Environment variable 'INFERENCE_ENGINE' must be one of [WATSONX, MOCK, VLLM]")
165
 
166
  logger.debug(f'Model generated label: \n{label}')
167
  logger.debug(f'Model prob_of_risk: \n{prob_of_risk}')
src/utils.py CHANGED
@@ -1,27 +1,6 @@
1
- import json
2
- from jinja2 import Template
3
  import argparse
4
  import os
5
 
6
- # with open('prompt_templates.json', mode='r', encoding="utf-8") as f:
7
- # prompt_templates = json.load(f)
8
-
9
- # def assessment_prompt(content):
10
- # return {"role": "user", "content": content}
11
-
12
- # def get_prompt_template(test_case, sub_catalog_name):
13
- # test_case_name = test_case['name']
14
- # if sub_catalog_name == 'harmful_content_in_user_prompt':
15
- # template_type = 'prompt'
16
- # elif sub_catalog_name == 'harmful_content_in_assistant_response':
17
- # template_type = 'prompt_response'
18
- # elif sub_catalog_name == 'rag_hallucination_risks':
19
- # template_type = test_case_name
20
- # return prompt_templates[f'{test_case_name}>{template_type}']
21
-
22
- # def get_prompt_from_test_case(test_case, sub_catalog_name):
23
- # return assessment_prompt(Template(get_prompt_template(test_case, sub_catalog_name)).render(**test_case))
24
-
25
  def get_messages(test_case, sub_catalog_name) -> list[dict[str,str]]:
26
  messages = []
27
 
@@ -76,14 +55,16 @@ def get_evaluated_component(sub_catalog_name, criteria_name):
76
  return component
77
 
78
  def to_title_case(input_string):
79
- if input_string == 'rag_hallucination_risks':
80
  return 'RAG Hallucination Risks'
81
  return ' '.join(word.capitalize() for word in input_string.split('_'))
82
 
 
 
 
83
  def to_snake_case(text):
84
  return text.lower().replace(" ", "_")
85
 
86
-
87
  def load_command_line_args():
88
  parser = argparse.ArgumentParser()
89
  parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")
 
 
 
1
  import argparse
2
  import os
3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  def get_messages(test_case, sub_catalog_name) -> list[dict[str,str]]:
5
  messages = []
6
 
 
55
  return component
56
 
57
  def to_title_case(input_string):
58
+ if input_string == 'rag_hallucination_risks':
59
  return 'RAG Hallucination Risks'
60
  return ' '.join(word.capitalize() for word in input_string.split('_'))
61
 
62
def capitalize_first_word(input_string):
    """Convert a snake_case string to a space-separated phrase.

    Only the first underscore-separated word is passed through
    ``str.capitalize``; the remaining words are kept exactly as given.
    """
    # str.split with an explicit separator always yields at least one item,
    # so the unpacking below cannot fail.
    head, *tail = input_string.split('_')
    return ' '.join([head.capitalize(), *tail])
64
+
65
  def to_snake_case(text):
66
  return text.lower().replace(" ", "_")
67
 
 
68
  def load_command_line_args():
69
  parser = argparse.ArgumentParser()
70
  parser.add_argument("--model_path", type=str, default=None, help="Path to the model or HF repo")