nmarafo commited on
Commit
719e873
verified
1 Parent(s): f6024ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -101
app.py CHANGED
@@ -1,110 +1,107 @@
1
  import gradio as gr
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
- from torch.nn.functional import softmax
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Cargar el modelo cuantizado y el tokenizer de google-shieldgemma-2b
 
 
7
  model = AutoModelForCausalLM.from_pretrained(
8
- "PrunaAI/google-shieldgemma-2b-bnb-4bit-smashed",
9
- trust_remote_code=True,
10
- device_map='auto'
 
 
11
  )
12
- tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")
13
-
14
- # Funci贸n para generar el prompt dependiendo del idioma seleccionado y el tipo de contenido
15
- def generar_prompt(message, tipo_clasificacion, idioma):
16
- if idioma == "Espa帽ol":
17
- politicas = {
18
- "Contenido Peligroso": """
19
- * "No al contenido peligroso": El mensaje no debe contener ni buscar la generaci贸n de contenido que cause da帽o a uno mismo o a otros (por ejemplo, acceder o fabricar armas de fuego o explosivos, promoci贸n del terrorismo, instrucciones para el suicidio).
20
- """,
21
- "Acoso": """
22
- * "No al acoso": El mensaje no debe contener ni buscar la generaci贸n de contenido que sea malicioso, intimidante, que fomente el bullying o abusivo dirigido a otra persona (por ejemplo, amenazas f铆sicas, negaci贸n de eventos tr谩gicos, menosprecio a v铆ctimas de violencia).
23
- """,
24
- "Discurso de Odio": """
25
- * "No al discurso de odio": El mensaje no debe contener ni buscar la generaci贸n de contenido que ataque la identidad o atributos protegidos (por ejemplo, insultos raciales, promoci贸n de la discriminaci贸n, llamados a la violencia contra grupos protegidos) o que deshumanice, menosprecie o vilipendie por motivos de raza, etnia, religi贸n, discapacidad, edad, nacionalidad, estatus de veterano, orientaci贸n sexual, g茅nero, identidad de g茅nero, casta, u otro estatus protegido.
26
- """,
27
- "Contenido Sexual": """
28
- * "No al contenido sexual expl铆cito": El mensaje no debe contener ni buscar la generaci贸n de contenido que haga referencia a actos sexuales u otro contenido lascivo (por ejemplo, descripciones gr谩ficas sexuales, contenido destinado a causar excitaci贸n). Los t茅rminos m茅dicos o cient铆ficos relacionados con la anatom铆a humana o la educaci贸n sexual est谩n permitidos.
29
- """
30
- }
31
- else: # English
32
- politicas = {
33
- "Dangerous Content": """
34
- * "No Dangerous Content": The prompt shall not contain or seek generation of content that harms oneself and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).
35
- """,
36
- "Harassment": """
37
- * "No Harassment": The prompt shall not contain or seek generation of content that is malicious, intimidating, bullying, or abusive content targeting another individual (e.g., physical threats, denial of tragic events, disparaging victims of violence).
38
- """,
39
- "Hate Speech": """
40
- * "No Hate Speech": The prompt shall not contain or seek generation of content that targets identity and/or protected attributes (e.g., racial slurs, promotion of discrimination, calls to violence against protected groups) or dehumanizing/belittling/vilifying based on race, ethnicity, religion, disability, age, nationality, veteran status, sexual orientation, gender, gender identity, caste, or any other protected status.
41
- """,
42
- "Sexually Explicit Information": """
43
- * "No Sexually Explicit Information": The prompt shall not contain or seek generation of content that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions, content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted.
44
- """
45
- }
46
-
47
- # Formatear el prompt seg煤n el tipo de clasificaci贸n seleccionado
48
- politica_seguridad = politicas.get(tipo_clasificacion, "")
49
-
50
- if idioma == "Espa帽ol":
51
- prompt = f"""Eres un experto en pol铆ticas tratando de determinar si un mensaje del usuario
52
- viola las pol铆ticas de seguridad definidas.
53
-
54
- <inicio_turno>
55
- Pregunta del usuario: {message.strip()}
56
- <fin_turno>
57
-
58
- Nuestro principio de seguridad se define a continuaci贸n:
59
-
60
- {politica_seguridad.strip()}
61
-
62
- 驴La pregunta del usuario viola el principio anterior? Tu respuesta debe comenzar con
63
- 'S铆' o 'No'. Luego, explica paso a paso para asegurarte de que respondemos correctamente.
64
- """
65
- else: # Default to English
66
- prompt = f"""You are a policy expert trying to help determine whether a user prompt is in
67
- violation of the defined safety policies.
68
-
69
- <start_of_turn>
70
- Human Question: {message.strip()}
71
- <end_of_turn>
72
-
73
- Our safety principle is defined below:
74
-
75
- {politica_seguridad.strip()}
76
-
77
- Does the human question violate the above principle? Your answer must start with
78
- 'Yes' or 'No'. Then walk through step by step to be sure we answer correctly.
79
- """
80
-
81
- return prompt
82
-
83
- # Funci贸n para procesar la respuesta y clasificar seg煤n las pol铆ticas de seguridad
84
- def respond(message, history: list[tuple[str, str]], system_message, max_tokens, temperature, top_p, language, harm_type):
85
- prompt = generar_prompt(message, harm_type, language)
86
-
87
- # Tokenizar el mensaje de entrada
88
- input_ids = tokenizer(prompt, return_tensors="pt").to(model.device)["input_ids"]
89
-
90
- # Generar la salida con el modelo cuantizado
91
- outputs = model.generate(input_ids, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p)
92
- response = tokenizer.decode(outputs[0], skip_special_tokens=True)
93
-
94
- return response
95
-
96
- # Crear la interfaz de Gradio con selecci贸n de idioma y tipo de contenido
97
  demo = gr.ChatInterface(
98
- respond,
99
- additional_inputs=[
100
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
101
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
102
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
103
- gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
104
- gr.Dropdown(choices=["English", "Espa帽ol"], value="English", label="Idioma/Language"),
105
- gr.Dropdown(choices=["Dangerous Content", "Harassment", "Hate Speech", "Sexually Explicit Information"], value="Harassment", label="Harm Type")
106
- ],
 
 
 
 
107
  )
108
 
 
109
  if __name__ == "__main__":
110
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
3
  import torch
4
+ from threading import Thread
5
+ import os
6
+
7
+ # Cargar el token de Hugging Face desde los secretos
8
+ token = os.environ["HF_TOKEN"]
9
+
10
+ # Configurar la cuantizaci贸n con bitsandbytes para reducir el uso de memoria
11
+ bnb_config = BitsAndBytesConfig(
12
+ load_in_4bit=True,
13
+ bnb_4bit_use_double_quant=True,
14
+ bnb_4bit_quant_type="nf4",
15
+ bnb_4bit_compute_dtype=torch.bfloat16
16
+ )
17
 
18
+ # Cargar el modelo cuantizado y el tokenizer
19
+ model_id = "PrunaAI/google-shieldgemma-2b-bnb-4bit-smashed"
20
+ tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
21
  model = AutoModelForCausalLM.from_pretrained(
22
+ model_id,
23
+ torch_dtype=torch.bfloat16,
24
+ device_map="auto",
25
+ quantization_config=bnb_config,
26
+ token=token
27
  )
28
+
29
+ # Definir terminadores de secuencia
30
+ terminators = [
31
+ tokenizer.eos_token_id,
32
+ tokenizer.convert_tokens_to_ids("<|eot_id|>")
33
+ ]
34
+
35
+ # Mensaje del sistema (system message)
36
+ SYS_PROMPT = """Eres un asistente que responde preguntas de forma conversacional.
37
+ Se te proporciona una pregunta y contexto adicional. Proporciona una respuesta clara y precisa.
38
+ Si no sabes la respuesta, simplemente di "No lo s茅". No inventes una respuesta."""
39
+
40
+ # Funci贸n principal para manejar la conversaci贸n
41
+ def talk(prompt, history):
42
+ formatted_prompt = f"Pregunta: {prompt}\nContexto: {SYS_PROMPT}"
43
+ formatted_prompt = formatted_prompt[:2000] # Limitar a 2000 caracteres para evitar problemas de OOM
44
+
45
+ # Preparar los mensajes para el modelo
46
+ messages = [{"role": "system", "content": SYS_PROMPT}, {"role": "user", "content": formatted_prompt}]
47
+
48
+ # Tokenizar el prompt
49
+ input_ids = tokenizer.apply_chat_template(
50
+ messages,
51
+ add_generation_prompt=True,
52
+ return_tensors="pt"
53
+ ).to(model.device)
54
+
55
+ # Configurar el generador de texto con streaming
56
+ streamer = TextIteratorStreamer(
57
+ tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
58
+ )
59
+
60
+ # Configurar los argumentos para la generaci贸n
61
+ generate_kwargs = dict(
62
+ input_ids=input_ids,
63
+ streamer=streamer,
64
+ max_new_tokens=512, # Reducido para evitar OOM
65
+ do_sample=True,
66
+ top_p=0.95,
67
+ temperature=0.75,
68
+ eos_token_id=terminators,
69
+ )
70
+
71
+ # Iniciar el hilo para la generaci贸n de texto
72
+ t = Thread(target=model.generate, kwargs=generate_kwargs)
73
+ t.start()
74
+
75
+ # Recoger los resultados de forma incremental
76
+ outputs = []
77
+ for text in streamer:
78
+ outputs.append(text)
79
+ yield "".join(outputs)
80
+
81
+
82
+ # Configuraci贸n de la interfaz de Gradio
83
+ TITLE = "# Chatbot de Respuestas"
84
+ DESCRIPTION = """
85
+ Este chatbot responde preguntas de manera conversacional usando un modelo cuantizado.
86
+ """
87
+
88
+ # Crear la interfaz del chatbot en Gradio
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  demo = gr.ChatInterface(
90
+ fn=talk,
91
+ chatbot=gr.Chatbot(
92
+ show_label=True,
93
+ show_share_button=True,
94
+ show_copy_button=True,
95
+ likeable=True,
96
+ layout="bubble",
97
+ bubble_full_width=False,
98
+ ),
99
+ theme="Soft",
100
+ examples=[["驴Qu茅 es la anarqu铆a?"]],
101
+ title=TITLE,
102
+ description=DESCRIPTION,
103
  )
104
 
105
+ # Lanzar la interfaz de Gradio
106
  if __name__ == "__main__":
107
+ demo.launch(debug=True)