biffboff committed
Commit 21c4566 · verified · 1 Parent(s): 0301278

Update app.py

Files changed (1)
  1. app.py  +104 -53
app.py CHANGED
@@ -4,17 +4,22 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import json
 from jsonschema import validate, ValidationError
 import logging
+import torch
 
 # Initialize logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger("StrategyInterpreterSpace")
 
 # Load model and tokenizer
-model_name = "EleutherAI/gpt-neo-2.7B"  # Updated model
+model_name = "EleutherAI/gpt-neo-2.7B"  # Using a smaller model to fit within memory constraints
 logger.info(f"Loading model '{model_name}'...")
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForCausalLM.from_pretrained(model_name)
-logger.info("Model loaded successfully.")
+try:
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name)
+    logger.info("Model loaded successfully.")
+except Exception as e:
+    logger.error(f"Failed to load model: {e}")
+    raise e
 
 # Define JSON schema
 schema = {
@@ -85,41 +90,69 @@ schema = {
     }
 }
 
-def interpret_strategy(description: str) -> str:
+def interpret_strategy(description: str) -> dict:
+    logger.info("Received strategy description for interpretation.")
     prompt = f"""
-You are an expert crypto trading assistant. Convert the following trading strategy description into a JSON format following this schema:
+You are an expert crypto trading assistant. Convert the following trading strategy description into a JSON format strictly following this schema:
 
 {json.dumps(schema, indent=2)}
 
-Include all indicators (only ones available in Ta-lib and pandas-ta), their parameters (only ones that are standard for ccxt and backtrader to support), assets (only ones that are available through BitGet) as trading pairs, conditions (only those supported by bitget, backtrader, finta, pandas-ta), risk management settings, and trade execution details (only those supported by ccxt, bitget and backtrader).
-Response should only return the JSON with the correct parameters, nothing else.
+Ensure that the response contains only valid JSON with the correct parameters. Do not include any additional text or explanations.
+
 Strategy Description:
 {description}
 
 JSON:
 """
-    inputs = tokenizer.encode(prompt, return_tensors="pt")
-    outputs = model.generate(
-        inputs,
-        max_length=1000,  # You can adjust this value as needed
-        temperature=0.7,
-        top_p=0.9,
-        do_sample=True,
-        eos_token_id=tokenizer.eos_token_id,
-    )
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    response_text = generated_text[len(prompt):].strip()
+    logger.debug(f"Prompt constructed: {prompt}")
+
+    try:
+        inputs = tokenizer.encode(prompt, return_tensors="pt")
+        logger.info("Tokenized the input prompt.")
+    except Exception as e:
+        logger.error(f"Error during tokenization: {e}")
+        return {"error": f"Error during tokenization: {e}"}
+
+    try:
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs,
+                max_length=800,  # Reduced max_length to prevent overly long outputs
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True,
+                eos_token_id=tokenizer.eos_token_id,
+            )
+        logger.info("Model generated output.")
+    except Exception as e:
+        logger.error(f"Error during model generation: {e}")
+        return {"error": f"Error during model generation: {e}"}
+
+    try:
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        logger.debug(f"Generated text: {generated_text}")
+        response_text = generated_text[len(prompt):].strip()
+        logger.debug(f"Response text after prompt removal: {response_text}")
+    except Exception as e:
+        logger.error(f"Error during decoding: {e}")
+        return {"error": f"Error during decoding: {e}"}
 
     # Validate JSON
     try:
         strategy_data = json.loads(response_text)
         validate(instance=strategy_data, schema=schema)
-        return json.dumps(strategy_data, indent=2)
-    except (json.JSONDecodeError, ValidationError) as e:
-        logger.error(f"Error interpreting strategy: {e}")
-        return f"Error interpreting strategy: {e}"
-
-def suggest_strategy(risk_level: str, market_type: str) -> str:
+        logger.info("Strategy interpreted successfully and validated against schema.")
+        return strategy_data  # Return as dict for Gradio's JSON output
+    except json.JSONDecodeError as e:
+        logger.error(f"JSON decoding error: {e}")
+        # Return raw text for debugging
+        return {"error": f"Error interpreting strategy: Invalid JSON format.\nGenerated Text:\n{response_text}\nDetails: {e}"}
+    except ValidationError as e:
+        logger.error(f"JSON validation error: {e}")
+        return {"error": f"Error interpreting strategy: JSON does not conform to schema.\nDetails: {e}"}
+
+def suggest_strategy(risk_level: str, market_type: str) -> dict:
+    logger.info("Received request to suggest a new strategy.")
     prompt = f"""Please create a unique crypto trading strategy suitable for a '{risk_level}' risk appetite in the '{market_type}' market.
 Ensure the JSON matches this schema:
 {json.dumps(schema, indent=2)}
@@ -127,18 +160,38 @@ Ensure the JSON matches this schema:
 Use indicators and conditions that can be applied by ccxt, bitget, pandas-ta, and backtrader.
 
 JSON:"""
+    logger.debug(f"Prompt constructed for strategy suggestion: {prompt}")
 
-    inputs = tokenizer.encode(prompt, return_tensors="pt")
-    outputs = model.generate(
-        inputs,
-        max_length=1000,  # You can adjust this value as needed
-        temperature=0.7,
-        top_p=0.9,
-        do_sample=True,
-        eos_token_id=tokenizer.eos_token_id,
-    )
-    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-    response_text = generated_text[len(prompt):].strip()
+    try:
+        inputs = tokenizer.encode(prompt, return_tensors="pt")
+        logger.info("Tokenized the suggestion prompt.")
+    except Exception as e:
+        logger.error(f"Error during tokenization: {e}")
+        return {"error": f"Error during tokenization: {e}"}
+
+    try:
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs,
+                max_length=800,  # Reduced max_length to prevent overly long outputs
+                temperature=0.7,
+                top_p=0.9,
+                do_sample=True,
+                eos_token_id=tokenizer.eos_token_id,
+            )
+        logger.info("Model generated suggestion output.")
+    except Exception as e:
+        logger.error(f"Error during model generation: {e}")
+        return {"error": f"Error during model generation: {e}"}
+
+    try:
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        logger.debug(f"Generated suggestion text: {generated_text}")
+        response_text = generated_text[len(prompt):].strip()
+        logger.debug(f"Suggestion response text after prompt removal: {response_text}")
+    except Exception as e:
+        logger.error(f"Error during decoding: {e}")
+        return {"error": f"Error during decoding: {e}"}
 
     # Validate JSON
     try:
@@ -146,15 +199,23 @@ JSON:"""
         validate(instance=strategy_data, schema=schema)
         if strategy_data.get("market_type") != market_type:
             raise ValueError("The generated strategy's market type does not match the selected market type.")
-        return json.dumps(strategy_data, indent=2)
-    except (json.JSONDecodeError, ValidationError, ValueError) as e:
-        logger.error(f"Error generating strategy: {e}")
-        return f"Error generating strategy: {e}"
+        logger.info("Strategy suggested successfully and validated against schema.")
+        return strategy_data  # Return as dict for Gradio's JSON output
+    except json.JSONDecodeError as e:
+        logger.error(f"JSON decoding error: {e}")
+        # Return raw text for debugging
+        return {"error": f"Error generating strategy: Invalid JSON format.\nGenerated Text:\n{response_text}\nDetails: {e}"}
+    except ValidationError as e:
+        logger.error(f"JSON validation error: {e}")
+        return {"error": f"Error generating strategy: JSON does not conform to schema.\nDetails: {e}"}
+    except ValueError as e:
+        logger.error(f"Market type mismatch error: {e}")
+        return {"error": f"Error generating strategy: {e}"}
 
 iface_interpret = gr.Interface(
     fn=interpret_strategy,
-    inputs=gr.Textbox(lines=10, placeholder="Enter your strategy description here..."),
-    outputs="text",
+    inputs=gr.Textbox(lines=10, placeholder="Enter your strategy description here...", label="Strategy Description"),
+    outputs=gr.JSON(label="Interpreted Strategy"),
     title="Strategy Interpreter",
     description="Convert trading strategy descriptions into structured JSON format."
 )
@@ -163,14 +224,4 @@ iface_suggest = gr.Interface(
     fn=suggest_strategy,
     inputs=[
        gr.Textbox(lines=1, placeholder="Enter risk level (e.g., medium)...", label="Risk Level"),
-        gr.Textbox(lines=1, placeholder="Enter market type (e.g., spot)...", label="Market Type")
-    ],
-    outputs="text",
-    title="Strategy Suggester",
-    description="Generate a unique trading strategy based on risk level and market type."
-)
-
-app = gr.TabbedInterface([iface_interpret, iface_suggest], ["Interpret Strategy", "Suggest Strategy"])
-
-if __name__ == "__main__":
-    app.launch()
+
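
Two hedged sketches related to the generation changes above; neither is part of this commit, and the helper names are illustrative. The hunks keep max_length=800, which budgets prompt tokens plus completion tokens together; because the prompt embeds the full JSON schema, a long schema can leave little room for the completion. A minimal alternative, assuming the tokenizer and model globals defined in app.py (generate_json and its max_new parameter are hypothetical), budgets only new tokens and strips the prompt by token count rather than by character count:

import torch

def generate_json(prompt: str, max_new: int = 512) -> str:
    # Sketch only: budget the completion with max_new_tokens instead of max_length,
    # so a schema-heavy prompt cannot consume the whole generation budget.
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
        )
    # Slice off the prompt by token count; decoding and then dropping len(prompt)
    # characters can misalign when the re-decoded prompt differs from the original.
    completion_ids = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()

The new except json.JSONDecodeError branches return the raw completion for debugging; if the model still wraps its answer in extra prose despite the stricter prompt, a small recovery step can pull the first balanced JSON object out of the text before json.loads and validate run. A standard-library sketch (extract_first_json is a hypothetical helper, not part of this commit):

import json

def extract_first_json(text: str) -> dict:
    # Sketch only: scan for the first balanced JSON object in a completion
    # that may contain leading or trailing prose.
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _end = decoder.raw_decode(text, start)
            return obj
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
    raise json.JSONDecodeError("No JSON object found in model output", text, 0)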