Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,17 +4,22 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
4 |
import json
|
5 |
from jsonschema import validate, ValidationError
|
6 |
import logging
|
|
|
7 |
|
8 |
# Initialize logging
|
9 |
logging.basicConfig(level=logging.INFO)
|
10 |
logger = logging.getLogger("StrategyInterpreterSpace")
|
11 |
|
12 |
# Load model and tokenizer
|
13 |
-
model_name = "EleutherAI/gpt-neo-2.7B" #
|
14 |
logger.info(f"Loading model '{model_name}'...")
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
18 |
|
19 |
# Define JSON schema
|
20 |
schema = {
|
@@ -85,41 +90,69 @@ schema = {
|
|
85 |
}
|
86 |
}
|
87 |
|
88 |
-
def interpret_strategy(description: str) ->
|
|
|
89 |
prompt = f"""
|
90 |
-
You are an expert crypto trading assistant. Convert the following trading strategy description into a JSON format following this schema:
|
91 |
|
92 |
{json.dumps(schema, indent=2)}
|
93 |
|
94 |
-
|
95 |
-
|
96 |
Strategy Description:
|
97 |
{description}
|
98 |
|
99 |
JSON:
|
100 |
"""
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
112 |
|
113 |
# Validate JSON
|
114 |
try:
|
115 |
strategy_data = json.loads(response_text)
|
116 |
validate(instance=strategy_data, schema=schema)
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
prompt = f"""Please create a unique crypto trading strategy suitable for a '{risk_level}' risk appetite in the '{market_type}' market.
|
124 |
Ensure the JSON matches this schema:
|
125 |
{json.dumps(schema, indent=2)}
|
@@ -127,18 +160,38 @@ Ensure the JSON matches this schema:
|
|
127 |
Use indicators and conditions that can be applied by ccxt, bitget, pandas-ta, and backtrader.
|
128 |
|
129 |
JSON:"""
|
|
|
130 |
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
142 |
|
143 |
# Validate JSON
|
144 |
try:
|
@@ -146,15 +199,23 @@ JSON:"""
|
|
146 |
validate(instance=strategy_data, schema=schema)
|
147 |
if strategy_data.get("market_type") != market_type:
|
148 |
raise ValueError("The generated strategy's market type does not match the selected market type.")
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
|
154 |
iface_interpret = gr.Interface(
|
155 |
fn=interpret_strategy,
|
156 |
-
inputs=gr.Textbox(lines=10, placeholder="Enter your strategy description here..."),
|
157 |
-
outputs="
|
158 |
title="Strategy Interpreter",
|
159 |
description="Convert trading strategy descriptions into structured JSON format."
|
160 |
)
|
@@ -163,14 +224,4 @@ iface_suggest = gr.Interface(
|
|
163 |
fn=suggest_strategy,
|
164 |
inputs=[
|
165 |
gr.Textbox(lines=1, placeholder="Enter risk level (e.g., medium)...", label="Risk Level"),
|
166 |
-
|
167 |
-
],
|
168 |
-
outputs="text",
|
169 |
-
title="Strategy Suggester",
|
170 |
-
description="Generate a unique trading strategy based on risk level and market type."
|
171 |
-
)
|
172 |
-
|
173 |
-
app = gr.TabbedInterface([iface_interpret, iface_suggest], ["Interpret Strategy", "Suggest Strategy"])
|
174 |
-
|
175 |
-
if __name__ == "__main__":
|
176 |
-
app.launch()
|
|
|
4 |
import json
|
5 |
from jsonschema import validate, ValidationError
|
6 |
import logging
|
7 |
+
import torch
|
8 |
|
9 |
# Initialize logging
|
10 |
logging.basicConfig(level=logging.INFO)
|
11 |
logger = logging.getLogger("StrategyInterpreterSpace")
|
12 |
|
13 |
# Load model and tokenizer
|
14 |
+
model_name = "EleutherAI/gpt-neo-2.7B" # Using a smaller model to fit within memory constraints
|
15 |
logger.info(f"Loading model '{model_name}'...")
|
16 |
+
try:
|
17 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
18 |
+
model = AutoModelForCausalLM.from_pretrained(model_name)
|
19 |
+
logger.info("Model loaded successfully.")
|
20 |
+
except Exception as e:
|
21 |
+
logger.error(f"Failed to load model: {e}")
|
22 |
+
raise e
|
23 |
|
24 |
# Define JSON schema
|
25 |
schema = {
|
|
|
90 |
}
|
91 |
}
|
92 |
|
93 |
+
def interpret_strategy(description: str) -> dict:
|
94 |
+
logger.info("Received strategy description for interpretation.")
|
95 |
prompt = f"""
|
96 |
+
You are an expert crypto trading assistant. Convert the following trading strategy description into a JSON format strictly following this schema:
|
97 |
|
98 |
{json.dumps(schema, indent=2)}
|
99 |
|
100 |
+
Ensure that the response contains only valid JSON with the correct parameters. Do not include any additional text or explanations.
|
101 |
+
|
102 |
Strategy Description:
|
103 |
{description}
|
104 |
|
105 |
JSON:
|
106 |
"""
|
107 |
+
logger.debug(f"Prompt constructed: {prompt}")
|
108 |
+
|
109 |
+
try:
|
110 |
+
inputs = tokenizer.encode(prompt, return_tensors="pt")
|
111 |
+
logger.info("Tokenized the input prompt.")
|
112 |
+
except Exception as e:
|
113 |
+
logger.error(f"Error during tokenization: {e}")
|
114 |
+
return {"error": f"Error during tokenization: {e}"}
|
115 |
+
|
116 |
+
try:
|
117 |
+
with torch.no_grad():
|
118 |
+
outputs = model.generate(
|
119 |
+
inputs,
|
120 |
+
max_length=800, # Reduced max_length to prevent overly long outputs
|
121 |
+
temperature=0.7,
|
122 |
+
top_p=0.9,
|
123 |
+
do_sample=True,
|
124 |
+
eos_token_id=tokenizer.eos_token_id,
|
125 |
+
)
|
126 |
+
logger.info("Model generated output.")
|
127 |
+
except Exception as e:
|
128 |
+
logger.error(f"Error during model generation: {e}")
|
129 |
+
return {"error": f"Error during model generation: {e}"}
|
130 |
+
|
131 |
+
try:
|
132 |
+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
133 |
+
logger.debug(f"Generated text: {generated_text}")
|
134 |
+
response_text = generated_text[len(prompt):].strip()
|
135 |
+
logger.debug(f"Response text after prompt removal: {response_text}")
|
136 |
+
except Exception as e:
|
137 |
+
logger.error(f"Error during decoding: {e}")
|
138 |
+
return {"error": f"Error during decoding: {e}"}
|
139 |
|
140 |
# Validate JSON
|
141 |
try:
|
142 |
strategy_data = json.loads(response_text)
|
143 |
validate(instance=strategy_data, schema=schema)
|
144 |
+
logger.info("Strategy interpreted successfully and validated against schema.")
|
145 |
+
return strategy_data # Return as dict for Gradio's JSON output
|
146 |
+
except json.JSONDecodeError as e:
|
147 |
+
logger.error(f"JSON decoding error: {e}")
|
148 |
+
# Return raw text for debugging
|
149 |
+
return {"error": f"Error interpreting strategy: Invalid JSON format.\nGenerated Text:\n{response_text}\nDetails: {e}"}
|
150 |
+
except ValidationError as e:
|
151 |
+
logger.error(f"JSON validation error: {e}")
|
152 |
+
return {"error": f"Error interpreting strategy: JSON does not conform to schema.\nDetails: {e}"}
|
153 |
+
|
154 |
+
def suggest_strategy(risk_level: str, market_type: str) -> dict:
|
155 |
+
logger.info("Received request to suggest a new strategy.")
|
156 |
prompt = f"""Please create a unique crypto trading strategy suitable for a '{risk_level}' risk appetite in the '{market_type}' market.
|
157 |
Ensure the JSON matches this schema:
|
158 |
{json.dumps(schema, indent=2)}
|
|
|
160 |
Use indicators and conditions that can be applied by ccxt, bitget, pandas-ta, and backtrader.
|
161 |
|
162 |
JSON:"""
|
163 |
+
logger.debug(f"Prompt constructed for strategy suggestion: {prompt}")
|
164 |
|
165 |
+
try:
|
166 |
+
inputs = tokenizer.encode(prompt, return_tensors="pt")
|
167 |
+
logger.info("Tokenized the suggestion prompt.")
|
168 |
+
except Exception as e:
|
169 |
+
logger.error(f"Error during tokenization: {e}")
|
170 |
+
return {"error": f"Error during tokenization: {e}"}
|
171 |
+
|
172 |
+
try:
|
173 |
+
with torch.no_grad():
|
174 |
+
outputs = model.generate(
|
175 |
+
inputs,
|
176 |
+
max_length=800, # Reduced max_length to prevent overly long outputs
|
177 |
+
temperature=0.7,
|
178 |
+
top_p=0.9,
|
179 |
+
do_sample=True,
|
180 |
+
eos_token_id=tokenizer.eos_token_id,
|
181 |
+
)
|
182 |
+
logger.info("Model generated suggestion output.")
|
183 |
+
except Exception as e:
|
184 |
+
logger.error(f"Error during model generation: {e}")
|
185 |
+
return {"error": f"Error during model generation: {e}"}
|
186 |
+
|
187 |
+
try:
|
188 |
+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
189 |
+
logger.debug(f"Generated suggestion text: {generated_text}")
|
190 |
+
response_text = generated_text[len(prompt):].strip()
|
191 |
+
logger.debug(f"Suggestion response text after prompt removal: {response_text}")
|
192 |
+
except Exception as e:
|
193 |
+
logger.error(f"Error during decoding: {e}")
|
194 |
+
return {"error": f"Error during decoding: {e}"}
|
195 |
|
196 |
# Validate JSON
|
197 |
try:
|
|
|
199 |
validate(instance=strategy_data, schema=schema)
|
200 |
if strategy_data.get("market_type") != market_type:
|
201 |
raise ValueError("The generated strategy's market type does not match the selected market type.")
|
202 |
+
logger.info("Strategy suggested successfully and validated against schema.")
|
203 |
+
return strategy_data # Return as dict for Gradio's JSON output
|
204 |
+
except json.JSONDecodeError as e:
|
205 |
+
logger.error(f"JSON decoding error: {e}")
|
206 |
+
# Return raw text for debugging
|
207 |
+
return {"error": f"Error generating strategy: Invalid JSON format.\nGenerated Text:\n{response_text}\nDetails: {e}"}
|
208 |
+
except ValidationError as e:
|
209 |
+
logger.error(f"JSON validation error: {e}")
|
210 |
+
return {"error": f"Error generating strategy: JSON does not conform to schema.\nDetails: {e}"}
|
211 |
+
except ValueError as e:
|
212 |
+
logger.error(f"Market type mismatch error: {e}")
|
213 |
+
return {"error": f"Error generating strategy: {e}"}
|
214 |
|
215 |
iface_interpret = gr.Interface(
|
216 |
fn=interpret_strategy,
|
217 |
+
inputs=gr.Textbox(lines=10, placeholder="Enter your strategy description here...", label="Strategy Description"),
|
218 |
+
outputs=gr.JSON(label="Interpreted Strategy"),
|
219 |
title="Strategy Interpreter",
|
220 |
description="Convert trading strategy descriptions into structured JSON format."
|
221 |
)
|
|
|
224 |
fn=suggest_strategy,
|
225 |
inputs=[
|
226 |
gr.Textbox(lines=1, placeholder="Enter risk level (e.g., medium)...", label="Risk Level"),
|
227 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|