Martín Santillán Cooper committed on
Commit
d46878a
1 Parent(s): 70e9d4b

start using local granite guardian model

Browse files
Files changed (5) hide show
  1. .env.example +2 -0
  2. app.py +25 -36
  3. logger.py +7 -0
  4. model.py +53 -0
  5. generate.py → utils.py +8 -10
.env.example ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ MODEL_PATH='../dmf_models/granite-guardian-8b-pipecleaner-r241024a'
2
+ USE_CONDA='true'
app.py CHANGED
@@ -1,9 +1,13 @@
1
  import gradio as gr
2
  from dotenv import load_dotenv
 
 
 
3
  import json
4
- from generate import generate_text, get_prompt_from_test_case
5
- from time import sleep
6
- # load_dotenv()
 
7
 
8
  catalog = {}
9
  all_test_cases = []
@@ -16,7 +20,7 @@ test_case_name = gr.HTML(f'<h2>{starting_test_case["name"]}</h2>')
16
  criteria = gr.Textbox(label="Definition", lines=3, interactive=False, value=starting_test_case['criteria'])
17
  context = gr.Textbox(label="Context", lines=3, interactive=True, value=starting_test_case['context'], visible=False)
18
  user_message = gr.Textbox(label="User Message", lines=3, interactive=True, value=starting_test_case['user_message'])
19
- assistant_message = gr.Textbox(label="Assistant Message", lines=3, interactive=True, value=starting_test_case['assistant_message'])
20
  catalog_buttons: dict[str,dict[str,gr.Button]] = {}
21
  result_text = gr.Textbox(label="Result", interactive=False)
22
  result_certainty = gr.Number(label="Certainty", interactive=False, value='')
@@ -31,32 +35,16 @@ for sub_catalog_name, sub_catalog in catalog.items():
31
  catalog_buttons[sub_catalog_name][test_case['name']] = \
32
  gr.Button(test_case['name'], elem_classes=elem_classes, variant='secondary', size='sm', elem_id=elem_id)
33
 
34
- # watsonx_api_url = os.getenv("WATSONX_URL", None)
35
- # watsonx_project_id = os.getenv("WATSONX_API_KEY", None)
36
- # watsonx_api_key = os.getenv("WATSONX_PROJECT_ID", None)
37
-
38
- # client = APIClient(credentials={
39
- # "url": watsonx_api_url,
40
- # "project_id": watsonx_project_id,
41
- # "api_key": watsonx_api_key
42
- # })
43
- # model = ModelInference(model_id=ModelTypes.LLAMA_3_8B_INSTRUCT, api_client=client)
44
-
45
- # client.set.default_project(watsonx_project_id)
46
-
47
- # bam_api_url = os.getenv("BAM_URL", None)
48
- # bam_api_key = os.getenv("BAM_API_KEY", None)
49
- # client = Client(credentials=Credentials(api_endpoint=bam_api_url, api_key=bam_api_key ))
50
-
51
-
52
- def on_test_case_click(link):
53
- selected_test_case = [t for sub_catalog in catalog.values() for t in sub_catalog if t['name'] == link][0]
54
  return {
55
  test_case_name: f'<h2>{selected_test_case["name"]}</h2>',
56
  criteria: selected_test_case['criteria'],
57
- context: selected_test_case['context'] if selected_test_case['context'] is not None else gr.update(visible=False),
58
  user_message: selected_test_case['user_message'],
59
- assistant_message: selected_test_case['assistant_message'],
60
  result_text: gr.update(value=''),
61
  result_certainty: gr.update(value='')
62
  }
@@ -65,18 +53,19 @@ def change_button_color(event: gr.EventData):
65
  return [gr.update(elem_classes=['catalog-button', 'selected']) if v.elem_id == event.target.elem_id else gr.update(elem_classes=['catalog-button']) for c in catalog_buttons.values() for v in c.values()]
66
 
67
  def on_submit(inputs):
68
- # prompt = get_prompt_from_test_case({
69
- # 'criteria': inputs[criteria],
70
- # 'context': inputs[context],
71
- # 'user_message': inputs[user_message],
72
- # 'assistant_message': inputs[assistant_message],
73
- # })
74
- # result = generate_text(prompt)
75
- # return result['assessment'], result['certainty']
76
- sleep(3)
77
- return 'Yes', 0.97
78
 
79
  with gr.Blocks(
 
80
  theme=gr.themes.Soft(font=[gr.themes.GoogleFont("IBM Plex Sans")]), css='styles.css') as demo:
81
  with gr.Row():
82
  gr.HTML('<h1>Granite Guardian</h1>', elem_classes='title')
 
1
  import gradio as gr
2
  from dotenv import load_dotenv
3
+
4
+ from utils import get_prompt_from_test_case
5
+ load_dotenv()
6
  import json
7
+ from model import generate_text
8
+ import logging
9
+
10
+ logging.getLogger('demo')
11
 
12
  catalog = {}
13
  all_test_cases = []
 
20
  criteria = gr.Textbox(label="Definition", lines=3, interactive=False, value=starting_test_case['criteria'])
21
  context = gr.Textbox(label="Context", lines=3, interactive=True, value=starting_test_case['context'], visible=False)
22
  user_message = gr.Textbox(label="User Message", lines=3, interactive=True, value=starting_test_case['user_message'])
23
+ assistant_message = gr.Textbox(label="Assistant Message", lines=3, interactive=True, visible=False, value=starting_test_case['assistant_message'])
24
  catalog_buttons: dict[str,dict[str,gr.Button]] = {}
25
  result_text = gr.Textbox(label="Result", interactive=False)
26
  result_certainty = gr.Number(label="Certainty", interactive=False, value='')
 
35
  catalog_buttons[sub_catalog_name][test_case['name']] = \
36
  gr.Button(test_case['name'], elem_classes=elem_classes, variant='secondary', size='sm', elem_id=elem_id)
37
 
38
+ def on_test_case_click(link, event: gr.EventData):
39
+ target_sub_catalog_name, target_test_case_name = event.target.elem_id.split('_')
40
+ selected_test_case = [t for sub_catalog_name, sub_catalog in catalog.items() for t in sub_catalog if t['name'] == link and sub_catalog_name == target_sub_catalog_name][0]
41
+ print(selected_test_case['assistant_message'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  return {
43
  test_case_name: f'<h2>{selected_test_case["name"]}</h2>',
44
  criteria: selected_test_case['criteria'],
45
+ context: selected_test_case['context'] if selected_test_case['context'] is not None else gr.update(visible=False, value=''),
46
  user_message: selected_test_case['user_message'],
47
+ assistant_message: gr.update(value=selected_test_case['assistant_message'], visible=True) if selected_test_case['assistant_message'] is not None else gr.update(visible=False, value=''),
48
  result_text: gr.update(value=''),
49
  result_certainty: gr.update(value='')
50
  }
 
53
  return [gr.update(elem_classes=['catalog-button', 'selected']) if v.elem_id == event.target.elem_id else gr.update(elem_classes=['catalog-button']) for c in catalog_buttons.values() for v in c.values()]
54
 
55
  def on_submit(inputs):
56
+ prompt = get_prompt_from_test_case({
57
+ 'criteria': inputs[criteria],
58
+ 'context': inputs[context],
59
+ 'user_message': inputs[user_message],
60
+ 'assistant_message': inputs[assistant_message],
61
+ })
62
+ result = generate_text(prompt)
63
+ return result['assessment'], result['certainty']
64
+ # sleep(3)
65
+ # return 'Yes', 0.97
66
 
67
  with gr.Blocks(
68
+ title='Granite Guardian',
69
  theme=gr.themes.Soft(font=[gr.themes.GoogleFont("IBM Plex Sans")]), css='styles.css') as demo:
70
  with gr.Row():
71
  gr.HTML('<h1>Granite Guardian</h1>', elem_classes='title')
logger.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
import logging

# Application-wide logger shared by the demo modules (imported as
# ``from logger import logger``). Every DEBUG-and-above record is
# emitted to the console (StreamHandler defaults to stderr).
logger = logging.getLogger('demo')
logger.setLevel(logging.DEBUG)

handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
logger.addHandler(handler)
model.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import logging.handlers
import os
from time import time

import jinja2
import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel

from logger import logger

# Gates GPU placement of the model. NOTE(review): the env var / flag name
# looks like a typo for USE_CUDA (it controls CUDA, not conda) — confirm
# before renaming, since .env.example sets USE_CONDA.
use_conda = os.getenv('USE_CONDA', "false") == "true"
device = "cuda"

# Local path to the Granite Guardian checkpoint, supplied via the .env file.
model_path = os.getenv('MODEL_PATH')
logger.info(f'Model path is "{model_path}"')

# Loaded once at import time; model.py must therefore be imported only after
# load_dotenv() has run so MODEL_PATH is populated.
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device if use_conda else None
)
def generate_text(prompt):
    """Run the guardian model on a chat prompt and score its Yes/No verdict.

    Args:
        prompt: a chat message dict (must have a 'content' key) consumed by
            the tokenizer's chat template.

    Returns:
        dict with 'assessment' ('Yes' or 'No') and 'certainty' (the winning
        class probability, rounded to 3 decimals and formatted as a string).
    """
    logger.debug('Starting evaluation...')
    logger.debug(f'Prompts content is: \n{prompt["content"]}')
    start = time()
    tokenized_chat = tokenizer.apply_chat_template(
        [prompt],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt")
    if use_conda:
        # BUG FIX: Tensor.to() is not in-place — the result must be
        # reassigned, otherwise the inputs stay on CPU while the model sits
        # on CUDA and the forward pass fails with a device mismatch.
        tokenized_chat = tokenized_chat.to(device)
    with torch.no_grad():
        # One forward pass for the next-token logits (used to score Yes/No),
        # plus a generate() call so the full assessment text can be logged.
        logits = model(tokenized_chat).logits
        gen_outputs = model.generate(tokenized_chat, max_new_tokens=128)

    generated_text = tokenizer.decode(gen_outputs[0])
    logger.debug(f'Model generated text: \n{generated_text}')
    vocab = tokenizer.get_vocab()
    # Logits of the two verdict tokens at the last position, normalized to a
    # two-way probability distribution: index 0 = 'No', index 1 = 'Yes'.
    selected_logits = logits[0, -1, [vocab['No'], vocab['Yes']]]
    probabilities = softmax(selected_logits, dim=0)

    prob = probabilities[1].item()
    logger.debug(f'Certainty is: {prob} from probabilities {probabilities}')
    assessment = 'Yes' if prob > 0.5 else 'No'
    # Report the probability of the *chosen* class, so certainty >= 0.5.
    certainty = 1 - prob if prob < 0.5 else prob
    certainty = f'{round(certainty,3)}'

    end = time()
    total = end - start
    logger.debug(f'it took {round(total/60, 2)} mins')

    return {'assessment': assessment, 'certainty': certainty}
generate.py → utils.py RENAMED
@@ -25,18 +25,16 @@ def turn_section_content(test_case):
25
  result = ''
26
  if test_case['context'] != '':
27
  result += 'Context: ' + test_case['context'] + '\n'
28
- result += 'User message: ' + test_case['user_message'] #+ '\n' + 'Assistant message: ' + test_case['assistant_message']
 
 
 
 
29
  return result
30
 
31
  def get_prompt_from_test_case(test_case):
32
- return json.dumps(assessment_prompt(assessment_prompt_content().format(
 
33
  turn_section_content=turn_section_content(test_case),
34
  criteria=test_case['criteria']
35
- )))
36
-
37
- def generate_text(prompt):
38
- result = requests.post('http://localhost:8081/generate', json={'input': prompt}).json()
39
- assessment = 'Yes' if result['certainty'] > 0.5 else 'No'
40
- certainty = 1 - result['certainty'] if result['certainty'] < 0.5 else result['certainty']
41
- certainty = f'{round(certainty,3)}'
42
- return {'assessment': assessment, 'certainty': certainty}
 
def turn_section_content(test_case):
    """Render the conversation-turns section of the guardian prompt.

    Args:
        test_case: dict with 'context', 'user_message' and
            'assistant_message' entries (strings; None is treated as absent).

    Returns:
        The formatted turns text: optional context line, the user message,
        and — when present — the assistant message.
    """
    result = ''
    # Truthiness check instead of != '' so a None context (possible when the
    # UI field is hidden) is treated as absent rather than raising TypeError.
    if test_case['context']:
        result += 'Context: ' + test_case['context'] + '\n'

    result += 'User message: ' + test_case['user_message']

    if test_case['assistant_message']:
        result += '\n\nAssistant message: ' + test_case['assistant_message'] + '\n'
    return result
34
 
def get_prompt_from_test_case(test_case):
    """Build the guardian model prompt for a test case.

    Fills the assessment prompt template with the rendered conversation
    turns and the test case's criteria.

    Args:
        test_case: dict with 'criteria', 'context', 'user_message' and
            'assistant_message' entries.

    Returns:
        The prompt structure produced by ``assessment_prompt``.
    """
    # Leftover debug print(json.dumps(test_case, indent=4)) removed; use the
    # shared 'demo' logger if this needs to be traced again.
    return assessment_prompt(assessment_prompt_content().format(
        turn_section_content=turn_section_content(test_case),
        criteria=test_case['criteria']
    ))