Tonic committed on
Commit 1874bf4
1 Parent(s): edc6972

Update app.py

Files changed (1)
  1. app.py +79 -67
app.py CHANGED
@@ -1,78 +1,87 @@
- from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
  from peft import PeftModel, PeftConfig
  import torch
  import gradio as gr
- import random
- from textwrap import wrap
-
- EXAMPLES = [
-     ["Hey Falcon! Any recommendations for my holidays in Abu Dhabi?"],
-     ["What's the Everett interpretation of quantum mechanics?"],
-     ["Give me a list of the top 10 dive sites you would recommend around the world."],
-     ["Can you tell me more about deep-water soloing?"],
-     ["Can you write a short tweet about the release of our latest AI model, Falcon LLM?"]
- ]
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
  base_model_id = "tiiuae/falcon-7b-instruct"
  model_directory = "Tonic/GaiaMiniMed"

  tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")
  model_config = AutoConfig.from_pretrained(base_model_id)
  peft_model = AutoModelForCausalLM.from_pretrained(model_directory, config=model_config)
  peft_model = PeftModel.from_pretrained(peft_model, model_directory)

- def format_prompt(message, history, system_prompt):
-     prompt = ""
-     if system_prompt:
-         prompt += f"System: {system_prompt}\n"
-     for user_prompt, bot_response in history:
-         prompt += f"User: {user_prompt}\n"
-         prompt += f"Falcon: {bot_response}\n"  # Response already contains "Falcon: "
-     prompt += f"""User: {message}
- Falcon:"""
-     return prompt
-
- seed = 42
-
- def generate(
-     prompt, history, system_prompt="", temperature=0.9, max_new_tokens=500, top_p=0.95, repetition_penalty=1.0,
- ):
-     temperature = float(temperature)
-     if temperature < 1e-2:
-         temperature = 1e-2
-     top_p = float(top_p)
-     global seed
-     generate_kwargs = dict(
-         temperature=temperature,
-         max_new_tokens=max_new_tokens,
-         top_p=top_p,
-         repetition_penalty=1.0,
-         stop_sequences="[END]",
-         do_sample=True,
-         seed=seed,
-     )
-     seed = seed + 1
-     formatted_prompt = format_prompt(prompt, history, system_prompt)
-
-     try:
-         stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
-         output = ""
-         for response in stream:
-             output += response.token.text
-             for stop_str in STOP_SEQUENCES:
-                 if output.endswith(stop_str):
-                     output = output[:-len(stop_str)]
-                     output = output.rstrip()
-             yield output
-         yield output
-     except Exception as e:
-         raise gr.Error(f"Error while generating: {e}")
-     return output

  additional_inputs=[
      gr.Textbox("", label="Optional system prompt"),
@@ -114,16 +123,19 @@ additional_inputs=[
      )
  ]

- with gr.Blocks() as demo:
-     title = "👋🏻Welcome to Tonic's GaiaMiniMed🦅⚕️Falcon Chat🚀"
-     description = "You can use this Space to test out the current model [(Tonic/GaiaMiniMed)](https://huggingface.co/Tonic/GaiaMiniMed) with chat memory optimized for falcon models or duplicate this Space and use it locally or on 🤗HuggingFace. [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
-
-     client = gr.Interface(
-         generate,
-         examples=EXAMPLES,
-         additional_inputs=additional_inputs,
      theme="ParityError/Anime"
  )

- # Launch the Gradio interface
- client.launch(show_api=True)
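
Note: the removed generate() streams from a client and checks STOP_SEQUENCES, neither of which is defined in the lines shown above. A minimal sketch of the setup this code appears to assume (an assumption, not part of this commit):

    from huggingface_hub import InferenceClient

    # Assumed: a text-generation client pointed at the base Falcon instruct model
    client = InferenceClient("tiiuae/falcon-7b-instruct")
    # Assumed: mirrors the stop_sequences="[END]" passed to text_generation above
    STOP_SEQUENCES = ["[END]"]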
 
+ from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM
  from peft import PeftModel, PeftConfig
  import torch
  import gradio as gr
+ import json
+ import os
+ import shutil
+ import requests

+ # Define the device
+ device = "cuda" if torch.cuda.is_available() else "cpu"

+ # Use model IDs as variables
  base_model_id = "tiiuae/falcon-7b-instruct"
  model_directory = "Tonic/GaiaMiniMed"

+ # Instantiate the tokenizer
  tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")
+ tokenizer.pad_token = tokenizer.eos_token
+ tokenizer.padding_side = 'left'

+ # Specify the configuration class for the model
  model_config = AutoConfig.from_pretrained(base_model_id)
+ # Load the GaiaMiniMed PEFT model with the specified configuration
  peft_model = AutoModelForCausalLM.from_pretrained(model_directory, config=model_config)
  peft_model = PeftModel.from_pretrained(peft_model, model_directory)
+ # Move the model to the selected device
+ peft_model = peft_model.to(device)
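
If the GaiaMiniMed adapter is LoRA-based, its weights could also be merged into the base model to speed up generation; a minimal sketch under that assumption (optional, not part of this commit):

    # merge_and_unload() folds LoRA weights into the base model, removing the
    # PEFT wrapper overhead at inference time; only valid for mergeable adapters.
    peft_model = peft_model.merge_and_unload()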

+ # Class to encapsulate the Falcon chatbot
+ class FalconChatBot:
+     def __init__(self, system_prompt="You are an expert medical analyst:"):
+         self.system_prompt = system_prompt
+
+     def process_history(self, history):
+         # Filter out special commands from the history
+         filtered_history = []
+         for message in history:
+             user_message = message["user"]
+             assistant_message = message["assistant"]
+             # Keep the turn only if the user message is not a special command
+             if not user_message.startswith("Falcon:"):
+                 filtered_history.append({"user": user_message, "assistant": assistant_message})
+         return filtered_history
+
+     def predict(self, system_prompt, user_message, assistant_message, history=None, max_length=500):
+         # Process the history to remove special commands
+         processed_history = self.process_history(history or [])
+
+         # Build the conversation in Falcon message format, including the processed history
+         conversation = f"{system_prompt}\n"
+         for message in processed_history:
+             past_assistant = message["assistant"]
+             conversation += f"User: {message['user']}\nFalcon:{' ' + past_assistant if past_assistant else ''}\n"
+         conversation += f"Falcon: {assistant_message if assistant_message else ''} User: {user_message}\nFalcon:"
+
+         # Encode the conversation using the tokenizer
+         input_ids = tokenizer.encode(conversation, return_tensors="pt", add_special_tokens=False).to(device)
+
+         # Generate a response with the Falcon model
+         output_ids = peft_model.generate(input_ids, max_length=max_length, use_cache=True, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.eos_token_id, temperature=0.4, do_sample=True)
+
+         # Decode only the newly generated tokens into text
+         response_text = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
+         return response_text

+ # Create the Falcon chatbot instance
+ falcon_bot = FalconChatBot()
+
+ # Define the Gradio interface
+ title = "👋🏻Welcome to Tonic's 🦅Falcon's Medical👨🏻‍⚕️Expert Chat🚀"
+ description = "You can use this Space to test out the GaiaMiniMed model [(Tonic/GaiaMiniMed)](https://huggingface.co/Tonic/GaiaMiniMed) or duplicate this Space and use it locally or on 🤗HuggingFace. [Join me on Discord to build together](https://discord.gg/VqTxc76K3u)."
+
+ examples = [
+     ["Assistant is a public health and medical expert ready to help the user.", [{"user": "Hi there, I have a question!", "assistant": "My name is Gaia, I'm a health and sanitation expert ready to answer your medical questions."}]],
+     ["Assistant is a public health and medical expert ready to help the user.", [{"user": "What is the proper treatment for buccal herpes?", "assistant": None}]]
+ ]

  additional_inputs=[
      gr.Textbox("", label="Optional system prompt"),
      )
  ]

+ iface = gr.Interface(
+     fn=falcon_bot.predict,
+     title=title,
+     description=description,
+     examples=examples,
+     inputs=[
+         gr.Textbox(label="System Prompt", lines=2),
+         gr.Textbox(label="User Message", lines=3),
+         gr.Textbox(label="Assistant Message", lines=2),
+     ] + additional_inputs,
+     outputs="text",
      theme="ParityError/Anime"
  )

+ # Launch the Gradio interface for the Falcon model
+ iface.launch()
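
As a quick local check, the new predict() path can be exercised without the Gradio UI; a minimal sketch with a hypothetical history entry:

    if __name__ == "__main__":
        # One prior exchange, in the {"user": ..., "assistant": ...} format
        # that process_history() expects
        history = [{"user": "Hi there, I have a question!", "assistant": "My name is Gaia, ask me anything medical."}]
        reply = falcon_bot.predict(
            "You are an expert medical analyst:",        # system prompt
            "What are common symptoms of dehydration?",  # user message
            "",                                          # no pending assistant message
            history,
            max_length=200,
        )
        print(reply)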