# app.py
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
from jsonschema import validate, ValidationError
import logging
import torch
# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("StrategyInterpreterSpace")
# Load model and tokenizer
model_name = "EleutherAI/gpt-neo-2.7B" # Using a smaller model to fit within memory constraints
logger.info(f"Loading model '{model_name}'...")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise
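
# Optional (not in the original): to use a GPU when one is available, the model and
# inputs could be placed on the same device, e.g.
#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     model.to(device)
# Each tokenized `inputs` tensor would then also need `.to(device)` before generate().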
# Define JSON schema
schema = {
"type": "object",
"required": [
"strategy_name",
"market_type",
"assets",
"trade_parameters",
"conditions",
"risk_management"
],
"properties": {
"strategy_name": {"type": "string"},
"market_type": {"type": "string", "enum": ["spot", "futures", "margin"]},
"assets": {"type": "array", "items": {"type": "string"}},
"trade_parameters": {
"type": "object",
"required": ["leverage", "order_type", "position_size"],
"properties": {
"leverage": {"type": "number"},
"order_type": {"type": "string"},
"position_size": {"type": "number"}
}
},
"conditions": {
"type": "object",
"required": ["entry", "exit"],
"properties": {
"entry": {
"type": "array",
"items": {"$ref": "#/definitions/condition"}
},
"exit": {
"type": "array",
"items": {"$ref": "#/definitions/condition"}
}
}
},
"risk_management": {
"type": "object",
"required": ["stop_loss", "take_profit", "trailing_stop_loss"],
"properties": {
"stop_loss": {"type": "number"},
"take_profit": {"type": "number"},
"trailing_stop_loss": {"type": "number"}
}
}
},
"definitions": {
"condition": {
"type": "object",
"required": ["indicator", "operator", "value", "timeframe"],
"properties": {
"indicator": {"type": "string"},
"operator": {"type": "string", "enum": [">", "<", "==", ">=", "<="]},
"value": {"type": ["string", "number"]},
"timeframe": {"type": "string"},
"indicator_parameters": {
"type": "object",
"properties": {
"period": {"type": "number"},
},
"additionalProperties": True
}
}
}
}
}
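
# Illustrative document that satisfies the schema above (hypothetical values, not part
# of the original file), useful as a reference when reading model output:
# {
#   "strategy_name": "EMA Crossover",
#   "market_type": "spot",
#   "assets": ["BTC/USDT"],
#   "trade_parameters": {"leverage": 1, "order_type": "limit", "position_size": 0.1},
#   "conditions": {
#     "entry": [{"indicator": "EMA", "operator": ">", "value": 50, "timeframe": "1h",
#                "indicator_parameters": {"period": 20}}],
#     "exit": [{"indicator": "RSI", "operator": ">=", "value": 70, "timeframe": "1h"}]
#   },
#   "risk_management": {"stop_loss": 2.0, "take_profit": 4.0, "trailing_stop_loss": 1.0}
# }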
def interpret_strategy(description: str) -> dict:
logger.info("Received strategy description for interpretation.")
prompt = f"""
You are an expert crypto trading assistant. Convert the following trading strategy description into a JSON format strictly following this schema:
{json.dumps(schema, indent=2)}
Ensure that the response contains only valid JSON with the correct parameters. Do not include any additional text or explanations.
Strategy Description:
{description}
JSON:
"""
logger.debug(f"Prompt constructed: {prompt}")
try:
inputs = tokenizer.encode(prompt, return_tensors="pt")
logger.info("Tokenized the input prompt.")
except Exception as e:
logger.error(f"Error during tokenization: {e}")
return {"error": f"Error during tokenization: {e}"}
    try:
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=800,  # Cap only newly generated tokens; max_length would also count the (long) prompt
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,  # GPT-Neo has no pad token; reuse EOS to avoid warnings
            )
        logger.info("Model generated output.")
    except Exception as e:
        logger.error(f"Error during model generation: {e}")
        return {"error": f"Error during model generation: {e}"}
    try:
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.debug(f"Generated text: {generated_text}")
        # Strip the echoed prompt; assumes decoding reproduces the prompt text verbatim
        response_text = generated_text[len(prompt):].strip()
        logger.debug(f"Response text after prompt removal: {response_text}")
    except Exception as e:
        logger.error(f"Error during decoding: {e}")
        return {"error": f"Error during decoding: {e}"}
    # Validate the generated JSON against the schema
    try:
        strategy_data = json.loads(response_text)
        validate(instance=strategy_data, schema=schema)
        logger.info("Strategy interpreted successfully and validated against schema.")
        return strategy_data  # Return as dict for Gradio's JSON output
    except json.JSONDecodeError as e:
        logger.error(f"JSON decoding error: {e}")
        # Return raw text for debugging
        return {"error": f"Error interpreting strategy: Invalid JSON format.\nGenerated Text:\n{response_text}\nDetails: {e}"}
    except ValidationError as e:
        logger.error(f"JSON validation error: {e}")
        return {"error": f"Error interpreting strategy: JSON does not conform to schema.\nDetails: {e}"}
def suggest_strategy(risk_level: str, market_type: str) -> dict:
logger.info("Received request to suggest a new strategy.")
prompt = f"""Please create a unique crypto trading strategy suitable for a '{risk_level}' risk appetite in the '{market_type}' market.
Ensure the JSON matches this schema:
{json.dumps(schema, indent=2)}
Use indicators and conditions that can be applied by ccxt, bitget, pandas-ta, and backtrader.
JSON:"""
logger.debug(f"Prompt constructed for strategy suggestion: {prompt}")
try:
inputs = tokenizer.encode(prompt, return_tensors="pt")
logger.info("Tokenized the suggestion prompt.")
except Exception as e:
logger.error(f"Error during tokenization: {e}")
return {"error": f"Error during tokenization: {e}"}
    try:
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=800,  # Cap only newly generated tokens; max_length would also count the (long) prompt
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,  # GPT-Neo has no pad token; reuse EOS to avoid warnings
            )
        logger.info("Model generated suggestion output.")
    except Exception as e:
        logger.error(f"Error during model generation: {e}")
        return {"error": f"Error during model generation: {e}"}
    try:
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        logger.debug(f"Generated suggestion text: {generated_text}")
        # Strip the echoed prompt; assumes decoding reproduces the prompt text verbatim
        response_text = generated_text[len(prompt):].strip()
        logger.debug(f"Suggestion response text after prompt removal: {response_text}")
    except Exception as e:
        logger.error(f"Error during decoding: {e}")
        return {"error": f"Error during decoding: {e}"}
    # Validate the generated JSON against the schema
    try:
        strategy_data = json.loads(response_text)
        validate(instance=strategy_data, schema=schema)
        if strategy_data.get("market_type") != market_type:
            raise ValueError("The generated strategy's market type does not match the selected market type.")
        logger.info("Strategy suggested successfully and validated against schema.")
        return strategy_data  # Return as dict for Gradio's JSON output
    except json.JSONDecodeError as e:
        logger.error(f"JSON decoding error: {e}")
        # Return raw text for debugging
        return {"error": f"Error generating strategy: Invalid JSON format.\nGenerated Text:\n{response_text}\nDetails: {e}"}
    except ValidationError as e:
        logger.error(f"JSON validation error: {e}")
        return {"error": f"Error generating strategy: JSON does not conform to schema.\nDetails: {e}"}
    except ValueError as e:
        logger.error(f"Market type mismatch error: {e}")
        return {"error": f"Error generating strategy: {e}"}
iface_interpret = gr.Interface(
    fn=interpret_strategy,
    inputs=gr.Textbox(lines=10, placeholder="Enter your strategy description here...", label="Strategy Description"),
    outputs=gr.JSON(label="Interpreted Strategy"),
    title="Strategy Interpreter",
    description="Convert trading strategy descriptions into structured JSON format."
)
iface_suggest = gr.Interface(
    fn=suggest_strategy,
    inputs=[
        gr.Textbox(lines=1, placeholder="Enter risk level (e.g., medium)...", label="Risk Level"),
        gr.Textbox(lines=1, placeholder="Enter market type (e.g., spot)...", label="Market Type")
    ],
    outputs=gr.JSON(label="Suggested Strategy"),
    title="Strategy Suggester",
    description="Generate a unique trading strategy based on risk level and market type."
)
app = gr.TabbedInterface([iface_interpret, iface_suggest], ["Interpret Strategy", "Suggest Strategy"])
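
# To run locally (assuming torch, transformers, jsonschema, and gradio are installed):
#     python app.py
# Gradio typically serves the tabbed UI at http://127.0.0.1:7860.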
if __name__ == "__main__":
    app.launch()