asthaa30 committed · Commit 5fb7b3d · verified · 1 Parent(s): df24897

Update app.py

Files changed (1):
  1. app.py +180 -77
app.py CHANGED
@@ -1,92 +1,195 @@
- # app.py
-
  import gradio as gr
- from transformers import AutoTokenizer
- from groq import Groq
  import os
- from huggingface_hub import login
- import logging
-
- # Setup logging
- logging.basicConfig(level=logging.DEBUG)
-
- # Initialize Groq API client
- try:
-     client = Groq(api_key=os.environ["GROQ_API_KEY"])
-     logging.info("Groq API client initialized.")
- except KeyError:
-     raise ValueError("GROQ_API_KEY environment variable not set.")
-
- # Load the Hugging Face token from the environment variable
- hf_token = os.getenv('HF_TOKEN')
- if hf_token is None:
-     raise ValueError("Hugging Face token not found. Please set it as an environment variable.")
-
- # Authentication token for Hugging Face
- login(token=hf_token)
-
- # Model identifier for Groq API (you can replace it with your HF model if needed)
- model_name = "asthaa30/nomiChroma3.1"
-
- # Load tokenizer (model will be accessed via Groq API)
- try:
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     logging.info(f"Tokenizer for model '{model_name}' loaded successfully.")
- except Exception as e:
-     raise ValueError(f"Failed to load tokenizer: {e}")
-
- def respond(
-     message: str,
-     history: list[tuple[str, str]],
-     system_message: str,
-     max_tokens: int,
-     temperature: float,
-     top_p: float,
- ) -> str:
-     messages = [{"role": "system", "content": system_message}]
-
-     for user_msg, assistant_msg in history:
-         if user_msg:
-             messages.append({"role": "user", "content": user_msg})
-         if assistant_msg:
-             messages.append({"role": "assistant", "content": assistant_msg})
-
-     messages.append({"role": "user", "content": message})
-
-     # Use Groq API to get the model's response
      try:
-         response = client.chat.completions.create(
-             model=model_name,
-             messages=messages,
-             max_tokens=max_tokens,
-             temperature=temperature,
-             top_p=top_p,
          )
-         assistant_message = response.choices[0].message['content']
-         logging.info(f"Received response from model: {assistant_message}")
      except Exception as e:
-         logging.error(f"An error occurred while getting model response: {str(e)}")
-         assistant_message = "An error occurred. Please try again later."
-
-     return assistant_message

  demo = gr.ChatInterface(
      respond,
      additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
          ),
      ],
-     title="Maritime Legal Assistant",
-     description="This chatbot provides legal assistance related to maritime laws using a fine-tuned model hosted on Hugging Face and integrated with Groq API.",
  )

  if __name__ == "__main__":
-     demo.launch(share=False)  # Set share to False or remove it

  import gradio as gr
+ import json
  import os
+ from groq import Groq
+ from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
+
+ # Use the fine-tuned maritime legal model
+ MODEL = "asthaa30/nomiChroma3.1"
+ client = Groq(api_key=os.environ["GROQ_API_KEY"])
+
+ # Define your tools if needed (e.g., legal research, document retrieval)
+ def legal_tool_function(arguments):
+     # Implement specific legal functions here
+     # Placeholder for legal research or similar functionality
+     return {"result": "Legal tool function response here"}
+
+ # Define your tools
+ legal_tool: ChatCompletionToolParam = {
+     "type": "function",
+     "function": {
+         "name": "legal_tool_function",
+         "description": "Legal assistant tool: use this for various maritime legal tasks.",
+         "parameters": {
+             "type": "object",
+             "properties": {
+                 "arguments": {
+                     "type": "string",
+                     "description": "Arguments for the legal function.",
+                 },
+             },
+             "required": ["arguments"],
+         },
+     },
+ }
+
+ tools = [legal_tool]
+
+ def call_function(tool_call, available_functions):
+     function_name = tool_call.function.name
+     if function_name not in available_functions:
+         return {
+             "tool_call_id": tool_call.id,
+             "role": "tool",
+             "content": f"Function {function_name} does not exist.",
+         }
+     function_to_call = available_functions[function_name]
+     function_args = json.loads(tool_call.function.arguments)
+     function_response = function_to_call(**function_args)
+     return {
+         "tool_call_id": tool_call.id,
+         "role": "tool",
+         "name": function_name,
+         "content": json.dumps(function_response),
+     }
+
+
56
+ def get_model_response(messages, inner_messages, message, system_message):
57
+ messages_for_model = []
58
+ for msg in messages:
59
+ native_messages = msg.get("metadata", {}).get("native_messages", [msg])
60
+ if isinstance(native_messages, list):
61
+ messages_for_model.extend(native_messages)
62
+ else:
63
+ messages_for_model.append(native_messages)
64
+
65
+ messages_for_model.insert(
66
+ 0,
67
+ {
68
+ "role": "system",
69
+ "content": system_message,
70
+ },
71
+ )
72
+ messages_for_model.append(
73
+ {
74
+ "role": "user",
75
+ "content": message,
76
+ }
77
+ )
78
+ messages_for_model.extend(inner_messages)
79
+
80
  try:
81
+ return client.chat.completions.create(
82
+ model=MODEL,
83
+ messages=messages_for_model,
84
+ tools=tools,
85
+ temperature=0.5,
86
+ top_p=0.65,
87
+ max_tokens=4096,
88
  )
 
 
89
  except Exception as e:
90
+ print(f"An error occurred while getting model response: {str(e)}")
91
+ print(messages_for_model)
92
+ return None
+
+ def respond(message, history, system_message):
+     inner_history = []
+
+     available_functions = {
+         "legal_tool_function": legal_tool_function,
+     }
+
+     assistant_content = ""
+     assistant_native_message_list = []
+
+     while True:
+         response = get_model_response(history, inner_history, message, system_message)
+         if response is None:
+             # The Groq call failed (already logged); report it instead of crashing on None
+             yield {"role": "assistant", "content": "An error occurred. Please try again later."}
+             return
+         response_message = response.choices[0].message
+
111
+ if not response_message.tool_calls and response_message.content is not None:
112
+ break
113
+
114
+ if response_message.tool_calls is not None:
115
+ assistant_native_message_list.append(response_message)
116
+ inner_history.append(response_message)
117
+
118
+ assistant_content += (
119
+ "```json\n"
120
+ + json.dumps(
121
+ [
122
+ tool_call.model_dump()
123
+ for tool_call in response_message.tool_calls
124
+ ],
125
+ indent=2,
126
+ )
127
+ + "\n```\n"
128
+ )
129
+ assistant_message = {
130
+ "role": "assistant",
131
+ "content": assistant_content,
132
+ "metadata": {"native_messages": assistant_native_message_list},
133
+ }
134
+
135
+ yield assistant_message
136
+
137
+ for tool_call in response_message.tool_calls:
138
+ function_response = call_function(tool_call, available_functions)
139
+ assistant_content += (
140
+ "```json\n"
141
+ + json.dumps(
142
+ {
143
+ "name": tool_call.function.name,
144
+ "arguments": json.loads(tool_call.function.arguments),
145
+ "response": json.loads(function_response["content"]),
146
+ },
147
+ indent=2,
148
+ )
149
+ + "\n```\n"
150
+ )
151
+ native_tool_message = {
152
+ "tool_call_id": tool_call.id,
153
+ "role": "tool",
154
+ "content": function_response["content"],
155
+ }
156
+ assistant_native_message_list.append(
157
+ native_tool_message
158
+ )
159
+ tool_message = {
160
+ "role": "assistant",
161
+ "content": assistant_content,
162
+ "metadata": {"native_messages": assistant_native_message_list},
163
+ }
164
+ yield tool_message
165
+ inner_history.append(native_tool_message)
166
+
167
+ assistant_content += response_message.content
168
+ assistant_native_message_list.append(response_message)
169
+
170
+ final_message = {
171
+ "role": "assistant",
172
+ "content": assistant_content,
173
+ "metadata": {"native_messages": assistant_native_message_list},
174
+ }
175
+
176
+ yield final_message
177
 
178
+ # Update the system prompt to be more relevant to maritime legal assistance
179
+ system_prompt = "You are a maritime legal assistant with expertise in maritime law. Provide detailed legal advice and information based on maritime legal principles and regulations."
180
 
181
  demo = gr.ChatInterface(
182
  respond,
183
  additional_inputs=[
184
+ gr.Textbox(
185
+ value=system_prompt,
186
+ label="System message",
 
 
 
 
 
 
187
  ),
188
  ],
189
+ type="messages",
190
+ title="Maritime Legal Assistant Chat",
191
+ description="This chatbot uses the fine-tuned maritime legal model to provide legal assistance and information related to maritime law.",
192
  )
193
 
194
  if __name__ == "__main__":
195
+ demo.launch()
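
The committed `legal_tool_function` is still a stub, so every tool call round-trips through the model without adding real information. Below is a minimal sketch of what a working implementation could look like, assuming a simple in-memory keyword lookup; the `MARITIME_TOPICS` table and its entries are hypothetical illustrations, not part of this commit:

```python
# Hypothetical sketch: a keyword-lookup stand-in for the stubbed legal_tool_function.
# A real deployment would query an actual legal-research backend or document index.
MARITIME_TOPICS = {
    "salvage": (
        "Salvage awards generally require voluntary service to imperiled "
        "maritime property that succeeds in whole or in part."
    ),
    "jones act": (
        "The Jones Act (46 U.S.C. § 30104) allows qualifying seamen to sue "
        "their employers for negligence."
    ),
}

def legal_tool_function(arguments: str) -> dict:
    """Return the first indexed topic whose keyword appears in the query."""
    query = arguments.lower()
    for keyword, summary in MARITIME_TOPICS.items():
        if keyword in query:
            return {"result": summary}
    return {"result": f"No indexed maritime topic matched: {arguments!r}"}
```

Because this sketch keeps the stub's `{"result": ...}` return shape, `call_function` can keep serializing the response with `json.dumps` unchanged.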