Spaces:
Running
on
Zero
Running
on
Zero
Martín Santillán Cooper
committed on
Commit
•
d46878a
1
Parent(s):
70e9d4b
start using local granite guardian model
Browse files- .env.example +2 -0
- app.py +25 -36
- logger.py +7 -0
- model.py +53 -0
- generate.py → utils.py +8 -10
.env.example
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
MODEL_PATH='../dmf_models/granite-guardian-8b-pipecleaner-r241024a'
|
2 |
+
USE_CONDA='true'
|
app.py
CHANGED
@@ -1,9 +1,13 @@
|
|
1 |
import gradio as gr
|
2 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
3 |
import json
|
4 |
-
from
|
5 |
-
|
6 |
-
|
|
|
7 |
|
8 |
catalog = {}
|
9 |
all_test_cases = []
|
@@ -16,7 +20,7 @@ test_case_name = gr.HTML(f'<h2>{starting_test_case["name"]}</h2>')
|
|
16 |
criteria = gr.Textbox(label="Definition", lines=3, interactive=False, value=starting_test_case['criteria'])
|
17 |
context = gr.Textbox(label="Context", lines=3, interactive=True, value=starting_test_case['context'], visible=False)
|
18 |
user_message = gr.Textbox(label="User Message", lines=3, interactive=True, value=starting_test_case['user_message'])
|
19 |
-
assistant_message = gr.Textbox(label="Assistant Message", lines=3, interactive=True, value=starting_test_case['assistant_message'])
|
20 |
catalog_buttons: dict[str,dict[str,gr.Button]] = {}
|
21 |
result_text = gr.Textbox(label="Result", interactive=False)
|
22 |
result_certainty = gr.Number(label="Certainty", interactive=False, value='')
|
@@ -31,32 +35,16 @@ for sub_catalog_name, sub_catalog in catalog.items():
|
|
31 |
catalog_buttons[sub_catalog_name][test_case['name']] = \
|
32 |
gr.Button(test_case['name'], elem_classes=elem_classes, variant='secondary', size='sm', elem_id=elem_id)
|
33 |
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
# client = APIClient(credentials={
|
39 |
-
# "url": watsonx_api_url,
|
40 |
-
# "project_id": watsonx_project_id,
|
41 |
-
# "api_key": watsonx_api_key
|
42 |
-
# })
|
43 |
-
# model = ModelInference(model_id=ModelTypes.LLAMA_3_8B_INSTRUCT, api_client=client)
|
44 |
-
|
45 |
-
# client.set.default_project(watsonx_project_id)
|
46 |
-
|
47 |
-
# bam_api_url = os.getenv("BAM_URL", None)
|
48 |
-
# bam_api_key = os.getenv("BAM_API_KEY", None)
|
49 |
-
# client = Client(credentials=Credentials(api_endpoint=bam_api_url, api_key=bam_api_key ))
|
50 |
-
|
51 |
-
|
52 |
-
def on_test_case_click(link):
|
53 |
-
selected_test_case = [t for sub_catalog in catalog.values() for t in sub_catalog if t['name'] == link][0]
|
54 |
return {
|
55 |
test_case_name: f'<h2>{selected_test_case["name"]}</h2>',
|
56 |
criteria: selected_test_case['criteria'],
|
57 |
-
context: selected_test_case['context'] if selected_test_case['context'] is not None else gr.update(visible=False),
|
58 |
user_message: selected_test_case['user_message'],
|
59 |
-
assistant_message: selected_test_case['assistant_message'],
|
60 |
result_text: gr.update(value=''),
|
61 |
result_certainty: gr.update(value='')
|
62 |
}
|
@@ -65,18 +53,19 @@ def change_button_color(event: gr.EventData):
|
|
65 |
return [gr.update(elem_classes=['catalog-button', 'selected']) if v.elem_id == event.target.elem_id else gr.update(elem_classes=['catalog-button']) for c in catalog_buttons.values() for v in c.values()]
|
66 |
|
67 |
def on_submit(inputs):
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
sleep(3)
|
77 |
-
return 'Yes', 0.97
|
78 |
|
79 |
with gr.Blocks(
|
|
|
80 |
theme=gr.themes.Soft(font=[gr.themes.GoogleFont("IBM Plex Sans")]), css='styles.css') as demo:
|
81 |
with gr.Row():
|
82 |
gr.HTML('<h1>Granite Guardian</h1>', elem_classes='title')
|
|
|
import gradio as gr
from dotenv import load_dotenv

from utils import get_prompt_from_test_case

# .env must be loaded BEFORE importing `model`: model.py reads MODEL_PATH and
# USE_CONDA from the environment at import time.
load_dotenv()
import json
from model import generate_text
import logging

# Bind the shared 'demo' logger (configured in logger.py); the bare
# logging.getLogger('demo') call discarded its result and did nothing.
logger = logging.getLogger('demo')
|
11 |
|
12 |
catalog = {}
|
13 |
all_test_cases = []
|
|
|
20 |
criteria = gr.Textbox(label="Definition", lines=3, interactive=False, value=starting_test_case['criteria'])
|
21 |
context = gr.Textbox(label="Context", lines=3, interactive=True, value=starting_test_case['context'], visible=False)
|
22 |
user_message = gr.Textbox(label="User Message", lines=3, interactive=True, value=starting_test_case['user_message'])
|
23 |
+
assistant_message = gr.Textbox(label="Assistant Message", lines=3, interactive=True, visible=False, value=starting_test_case['assistant_message'])
|
24 |
catalog_buttons: dict[str,dict[str,gr.Button]] = {}
|
25 |
result_text = gr.Textbox(label="Result", interactive=False)
|
26 |
result_certainty = gr.Number(label="Certainty", interactive=False, value='')
|
|
|
35 |
catalog_buttons[sub_catalog_name][test_case['name']] = \
|
36 |
gr.Button(test_case['name'], elem_classes=elem_classes, variant='secondary', size='sm', elem_id=elem_id)
|
37 |
|
38 |
+
def on_test_case_click(link, event: gr.EventData):
    """Populate the form with the test case selected in the catalog sidebar.

    `link` is the clicked button's label (the test-case name). The button's
    elem_id encodes '<sub_catalog>_<test_case_name>', which disambiguates
    identically named test cases living in different sub-catalogs.
    """
    # maxsplit=1: test-case names may themselves contain underscores, and a
    # plain split('_') would then raise on the two-name unpacking.
    target_sub_catalog_name, _ = event.target.elem_id.split('_', 1)
    selected_test_case = next(
        t
        for sub_catalog_name, sub_catalog in catalog.items()
        for t in sub_catalog
        if t['name'] == link and sub_catalog_name == target_sub_catalog_name
    )
    # Optional fields toggle visibility: the Textboxes start hidden, so a plain
    # value is not enough — visible=True must be sent when the field applies.
    if selected_test_case['context'] is not None:
        context_update = gr.update(value=selected_test_case['context'], visible=True)
    else:
        context_update = gr.update(visible=False, value='')
    if selected_test_case['assistant_message'] is not None:
        assistant_update = gr.update(value=selected_test_case['assistant_message'], visible=True)
    else:
        assistant_update = gr.update(visible=False, value='')
    return {
        test_case_name: f'<h2>{selected_test_case["name"]}</h2>',
        criteria: selected_test_case['criteria'],
        context: context_update,
        user_message: selected_test_case['user_message'],
        assistant_message: assistant_update,
        result_text: gr.update(value=''),
        result_certainty: gr.update(value='')
    }
|
|
|
53 |
return [gr.update(elem_classes=['catalog-button', 'selected']) if v.elem_id == event.target.elem_id else gr.update(elem_classes=['catalog-button']) for c in catalog_buttons.values() for v in c.values()]
|
54 |
|
55 |
def on_submit(inputs):
    """Run the guardian model on the current form contents.

    `inputs` is the Gradio component->value mapping for the form fields.
    Returns (assessment, certainty) feeding the Result / Certainty outputs.
    """
    prompt = get_prompt_from_test_case({
        'criteria': inputs[criteria],
        'context': inputs[context],
        'user_message': inputs[user_message],
        'assistant_message': inputs[assistant_message],
    })
    result = generate_text(prompt)
    return result['assessment'], result['certainty']
|
66 |
|
67 |
with gr.Blocks(
|
68 |
+
title='Granite Guardian',
|
69 |
theme=gr.themes.Soft(font=[gr.themes.GoogleFont("IBM Plex Sans")]), css='styles.css') as demo:
|
70 |
with gr.Row():
|
71 |
gr.HTML('<h1>Granite Guardian</h1>', elem_classes='title')
|
logger.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging

# Application-wide logger. Other modules obtain it via `from logger import logger`
# (model.py does) or logging.getLogger('demo').
logger = logging.getLogger('demo')
logger.setLevel(logging.DEBUG)

handler = logging.StreamHandler()
handler.setLevel(logging.DEBUG)
# Timestamp + level prefix so the DEBUG output from model.py is attributable;
# a bare StreamHandler would emit unannotated messages.
handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s %(name)s: %(message)s'))
logger.addHandler(handler)
|
model.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
from time import time

import torch
from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForCausalLM

from logger import logger

# NOTE(review): the flag is named USE_CONDA but it gates CUDA placement below —
# presumably a typo for USE_CUDA; the env-var name is kept for .env compatibility.
use_conda = os.getenv('USE_CONDA', "false") == "true"
device = "cuda"

# e.g. "../dmf_models/granite-guardian-8b-pipecleaner-r241024a" (see .env.example)
model_path = os.getenv('MODEL_PATH')
logger.info(f'Model path is "{model_path}"')

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map=device if use_conda else None
)
model.eval()  # inference only: disable dropout / training-mode behavior
|
19 |
+
|
20 |
+
|
21 |
+
def generate_text(prompt):
    """Score `prompt` with the local Granite Guardian model.

    prompt: a chat-message dict with a 'content' key (as produced by
    get_prompt_from_test_case), passed through the tokenizer's chat template.

    Returns {'assessment': 'Yes'|'No', 'certainty': str} where certainty is
    the probability of the *chosen* label, rounded to 3 decimals, as a string.
    """
    logger.debug('Starting evaluation...')
    logger.debug(f'Prompts content is: \n{prompt["content"]}')
    start = time()
    tokenized_chat = tokenizer.apply_chat_template(
        [prompt],
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt")
    if use_conda:
        # Tensor.to() is NOT in-place — the result must be reassigned, or the
        # input silently stays on CPU while the model sits on CUDA.
        tokenized_chat = tokenized_chat.to(device)
    with torch.no_grad():
        logits = model(tokenized_chat).logits
        gen_outputs = model.generate(tokenized_chat, max_new_tokens=128)

    # Generated text is only logged for debugging; the verdict comes from the
    # 'No'/'Yes' logits at the last prompt position below.
    generated_text = tokenizer.decode(gen_outputs[0])
    logger.debug(f'Model generated text: \n{generated_text}')
    vocab = tokenizer.get_vocab()
    selected_logits = logits[0, -1, [vocab['No'], vocab['Yes']]]
    probabilities = softmax(selected_logits, dim=0)

    prob = probabilities[1].item()  # P('Yes')
    logger.debug(f'Certainty is: {prob} from probabilities {probabilities}')
    certainty = prob
    assessment = 'Yes' if certainty > 0.5 else 'No'
    # Report confidence in the chosen label rather than raw P('Yes').
    certainty = 1 - certainty if certainty < 0.5 else certainty
    certainty = f'{round(certainty,3)}'

    end = time()
    total = end - start
    logger.debug(f'it took {round(total/60, 2)} mins')

    return {'assessment': assessment, 'certainty': certainty}
|
generate.py → utils.py
RENAMED
@@ -25,18 +25,16 @@ def turn_section_content(test_case):
|
|
25 |
result = ''
|
26 |
if test_case['context'] != '':
|
27 |
result += 'Context: ' + test_case['context'] + '\n'
|
28 |
-
|
|
|
|
|
|
|
|
|
29 |
return result
|
30 |
|
31 |
def get_prompt_from_test_case(test_case):
|
32 |
-
|
|
|
33 |
turn_section_content=turn_section_content(test_case),
|
34 |
criteria=test_case['criteria']
|
35 |
-
))
|
36 |
-
|
37 |
-
def generate_text(prompt):
|
38 |
-
result = requests.post('http://localhost:8081/generate', json={'input': prompt}).json()
|
39 |
-
assessment = 'Yes' if result['certainty'] > 0.5 else 'No'
|
40 |
-
certainty = 1 - result['certainty'] if result['certainty'] < 0.5 else result['certainty']
|
41 |
-
certainty = f'{round(certainty,3)}'
|
42 |
-
return {'assessment': assessment, 'certainty': certainty}
|
|
|
def turn_section_content(test_case):
    """Format the conversation turns (context / user / assistant) of a test case.

    Empty ('') or absent (None) context and assistant_message sections are
    omitted; the user message is always included.
    """
    result = ''
    # Truthiness check: tolerate both '' and None — app.py represents a missing
    # field as None, and `None != ''` would send us into a TypeError on concat.
    if test_case.get('context'):
        result += 'Context: ' + test_case['context'] + '\n'

    result += 'User message: ' + test_case['user_message']

    if test_case.get('assistant_message'):
        result += '\n\nAssistant message: ' + test_case['assistant_message'] + '\n'
    return result
|
34 |
|
35 |
def get_prompt_from_test_case(test_case):
    """Render the full assessment prompt for a test case.

    Combines the shared assessment template with the formatted conversation
    turns and the case's criteria. Relies on assessment_prompt /
    assessment_prompt_content defined elsewhere in this module.
    """
    # Debug print of the whole test case removed: it dumped user content to
    # stdout on every submission.
    return assessment_prompt(assessment_prompt_content().format(
        turn_section_content=turn_section_content(test_case),
        criteria=test_case['criteria']
    ))
|
|
|
|
|
|
|
|
|
|
|
|
|
|