Spaces:
Sleeping
Sleeping
# app.py | |
import gradio as gr | |
from transformers import AutoModelForCausalLM, AutoTokenizer | |
import json | |
from jsonschema import validate, ValidationError | |
import logging | |
import torch | |
# Configure process-wide logging at INFO level and create the named logger
# that the rest of this module writes to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("StrategyInterpreterSpace")
# Load the tokenizer and causal language model once at import time so every
# request handled by this module reuses the same instances.
model_name = "EleutherAI/gpt-neo-2.7B"  # Using a smaller model to fit within memory constraints
logger.info(f"Loading model '{model_name}'...")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
except Exception as load_error:
    # Without a model nothing below can work: log the failure and abort startup.
    logger.error(f"Failed to load model: {load_error}")
    raise load_error
else:
    logger.info("Model loaded successfully.")
# JSON schema describing a trading strategy. It is assembled from named
# sub-schemas; key order is kept stable because the schema is serialised
# with json.dumps and embedded verbatim in the model prompts.

# A single entry/exit condition, referenced via "#/definitions/condition".
_CONDITION_SCHEMA = {
    "type": "object",
    "required": ["indicator", "operator", "value", "timeframe"],
    "properties": {
        "indicator": {"type": "string"},
        "operator": {"type": "string", "enum": [">", "<", "==", ">=", "<="]},
        "value": {"type": ["string", "number"]},
        "timeframe": {"type": "string"},
        # Optional indicator settings; extra keys beyond "period" are allowed.
        "indicator_parameters": {
            "type": "object",
            "properties": {
                "period": {"type": "number"},
            },
            "additionalProperties": True,
        },
    },
}

# Order-level settings of the strategy.
_TRADE_PARAMETERS_SCHEMA = {
    "type": "object",
    "required": ["leverage", "order_type", "position_size"],
    "properties": {
        "leverage": {"type": "number"},
        "order_type": {"type": "string"},
        "position_size": {"type": "number"},
    },
}

# Entry and exit rules, each a list of condition objects.
_CONDITIONS_SCHEMA = {
    "type": "object",
    "required": ["entry", "exit"],
    "properties": {
        "entry": {
            "type": "array",
            "items": {"$ref": "#/definitions/condition"},
        },
        "exit": {
            "type": "array",
            "items": {"$ref": "#/definitions/condition"},
        },
    },
}

# Numeric risk limits of the strategy.
_RISK_MANAGEMENT_SCHEMA = {
    "type": "object",
    "required": ["stop_loss", "take_profit", "trailing_stop_loss"],
    "properties": {
        "stop_loss": {"type": "number"},
        "take_profit": {"type": "number"},
        "trailing_stop_loss": {"type": "number"},
    },
}

schema = {
    "type": "object",
    "required": [
        "strategy_name",
        "market_type",
        "assets",
        "trade_parameters",
        "conditions",
        "risk_management",
    ],
    "properties": {
        "strategy_name": {"type": "string"},
        "market_type": {"type": "string", "enum": ["spot", "futures", "margin"]},
        "assets": {"type": "array", "items": {"type": "string"}},
        "trade_parameters": _TRADE_PARAMETERS_SCHEMA,
        "conditions": _CONDITIONS_SCHEMA,
        "risk_management": _RISK_MANAGEMENT_SCHEMA,
    },
    "definitions": {
        "condition": _CONDITION_SCHEMA,
    },
}
def interpret_strategy(description: str) -> dict:
    """Convert a free-text trading strategy description into validated JSON.

    The description is embedded in a prompt together with the module-level
    ``schema``; the language model's completion is parsed as JSON and
    validated against that schema.

    Args:
        description: Natural-language description of the trading strategy.

    Returns:
        The parsed strategy as a dict when the completion is valid JSON that
        conforms to ``schema``; otherwise a dict with a single ``"error"``
        key describing which step failed. The function never raises for
        tokenization, generation, decoding, parsing, or validation errors.
    """
    logger.info("Received strategy description for interpretation.")

    # Skip an expensive generation run when there is nothing to interpret.
    if not description or not description.strip():
        logger.error("Empty strategy description received.")
        return {"error": "Error interpreting strategy: The strategy description is empty."}

    prompt = f"""
You are an expert crypto trading assistant. Convert the following trading strategy description into a JSON format strictly following this schema:
{json.dumps(schema, indent=2)}
Ensure that the response contains only valid JSON with the correct parameters. Do not include any additional text or explanations.
Strategy Description:
{description}
JSON:
"""
    logger.debug("Prompt constructed: %s", prompt)

    try:
        # Calling the tokenizer (instead of .encode) also yields the
        # attention mask that generate() expects.
        encoded = tokenizer(prompt, return_tensors="pt")
        prompt_token_count = encoded["input_ids"].shape[-1]
        logger.info("Tokenized the input prompt.")
    except Exception as e:
        logger.error("Error during tokenization: %s", e)
        return {"error": f"Error during tokenization: {e}"}

    try:
        with torch.no_grad():
            outputs = model.generate(
                **encoded,
                # max_length would count the prompt tokens too; the prompt
                # embeds the whole schema, so budget the completion only.
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                # The model defines no pad token; reuse EOS for padding.
                pad_token_id=tokenizer.eos_token_id,
            )
        logger.info("Model generated output.")
    except Exception as e:
        logger.error("Error during model generation: %s", e)
        return {"error": f"Error during model generation: {e}"}

    try:
        # Drop the prompt at token level: slicing the decoded string by
        # len(prompt) breaks whenever decoding does not reproduce the prompt
        # character for character.
        completion_tokens = outputs[0][prompt_token_count:]
        response_text = tokenizer.decode(completion_tokens, skip_special_tokens=True).strip()
        logger.debug("Response text after prompt removal: %s", response_text)
    except Exception as e:
        logger.error("Error during decoding: %s", e)
        return {"error": f"Error during decoding: {e}"}

    # Validate JSON
    try:
        # The model may wrap the object in extra text; parse the first JSON
        # value that starts at the first "{" and ignore anything after it.
        json_start = response_text.find("{")
        if json_start == -1:
            raise json.JSONDecodeError("No JSON object found", response_text, 0)
        strategy_data, _ = json.JSONDecoder().raw_decode(response_text[json_start:])
        validate(instance=strategy_data, schema=schema)
    except json.JSONDecodeError as e:
        logger.error("JSON decoding error: %s", e)
        # Return raw text for debugging
        return {"error": f"Error interpreting strategy: Invalid JSON format.\nGenerated Text:\n{response_text}\nDetails: {e}"}
    except ValidationError as e:
        logger.error("JSON validation error: %s", e)
        return {"error": f"Error interpreting strategy: JSON does not conform to schema.\nDetails: {e}"}

    logger.info("Strategy interpreted successfully and validated against schema.")
    return strategy_data  # Return as dict for Gradio's JSON output
def suggest_strategy(risk_level: str, market_type: str) -> dict:
    """Ask the language model for a new strategy matching a risk level and market.

    The completion is parsed as JSON, validated against the module-level
    ``schema``, and additionally required to declare the requested market
    type.

    Args:
        risk_level: Free-text risk appetite (e.g. ``"medium"``).
        market_type: Target market; compared case-insensitively against the
            values the schema allows for ``market_type``.

    Returns:
        The generated strategy as a dict on success; otherwise a dict with a
        single ``"error"`` key describing which step failed. The function
        never raises for tokenization, generation, decoding, parsing, or
        validation errors.
    """
    logger.info("Received request to suggest a new strategy.")

    # Normalise the inputs and reject impossible requests up front: a market
    # type outside the schema enum could never pass validation, so running
    # the model for it would only waste a generation.
    risk_level = (risk_level or "").strip()
    market_type = (market_type or "").strip().lower()
    allowed_market_types = schema["properties"]["market_type"]["enum"]
    if not risk_level:
        logger.error("Empty risk level received.")
        return {"error": "Error generating strategy: The risk level is empty."}
    if market_type not in allowed_market_types:
        logger.error("Unsupported market type: %r", market_type)
        return {
            "error": (
                f"Error generating strategy: Unsupported market type '{market_type}'. "
                f"Choose one of: {', '.join(allowed_market_types)}."
            )
        }

    prompt = f"""Please create a unique crypto trading strategy suitable for a '{risk_level}' risk appetite in the '{market_type}' market.
Ensure the JSON matches this schema:
{json.dumps(schema, indent=2)}
Use indicators and conditions that can be applied by ccxt, bitget, pandas-ta, and backtrader.
JSON:"""
    logger.debug("Prompt constructed for strategy suggestion: %s", prompt)

    try:
        # Calling the tokenizer (instead of .encode) also yields the
        # attention mask that generate() expects.
        encoded = tokenizer(prompt, return_tensors="pt")
        prompt_token_count = encoded["input_ids"].shape[-1]
        logger.info("Tokenized the suggestion prompt.")
    except Exception as e:
        logger.error("Error during tokenization: %s", e)
        return {"error": f"Error during tokenization: {e}"}

    try:
        with torch.no_grad():
            outputs = model.generate(
                **encoded,
                # max_length would count the prompt tokens too; the prompt
                # embeds the whole schema, so budget the completion only.
                max_new_tokens=512,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                # The model defines no pad token; reuse EOS for padding.
                pad_token_id=tokenizer.eos_token_id,
            )
        logger.info("Model generated suggestion output.")
    except Exception as e:
        logger.error("Error during model generation: %s", e)
        return {"error": f"Error during model generation: {e}"}

    try:
        # Drop the prompt at token level: slicing the decoded string by
        # len(prompt) breaks whenever decoding does not reproduce the prompt
        # character for character.
        completion_tokens = outputs[0][prompt_token_count:]
        response_text = tokenizer.decode(completion_tokens, skip_special_tokens=True).strip()
        logger.debug("Suggestion response text after prompt removal: %s", response_text)
    except Exception as e:
        logger.error("Error during decoding: %s", e)
        return {"error": f"Error during decoding: {e}"}

    # Validate JSON
    try:
        # The model may wrap the object in extra text; parse the first JSON
        # value that starts at the first "{" and ignore anything after it.
        json_start = response_text.find("{")
        if json_start == -1:
            raise json.JSONDecodeError("No JSON object found", response_text, 0)
        strategy_data, _ = json.JSONDecoder().raw_decode(response_text[json_start:])
        validate(instance=strategy_data, schema=schema)
    except json.JSONDecodeError as e:
        logger.error("JSON decoding error: %s", e)
        # Return raw text for debugging
        return {"error": f"Error generating strategy: Invalid JSON format.\nGenerated Text:\n{response_text}\nDetails: {e}"}
    except ValidationError as e:
        logger.error("JSON validation error: %s", e)
        return {"error": f"Error generating strategy: JSON does not conform to schema.\nDetails: {e}"}

    # Schema validation only guarantees *a* valid market type; make sure it
    # is the one that was requested.
    if strategy_data.get("market_type") != market_type:
        mismatch = "The generated strategy's market type does not match the selected market type."
        logger.error("Market type mismatch error: %s", mismatch)
        return {"error": f"Error generating strategy: {mismatch}"}

    logger.info("Strategy suggested successfully and validated against schema.")
    return strategy_data  # Return as dict for Gradio's JSON output
# Tab 1: free-text strategy description -> structured JSON.
iface_interpret = gr.Interface(
    fn=interpret_strategy,
    inputs=gr.Textbox(
        lines=10,
        placeholder="Enter your strategy description here...",
        label="Strategy Description",
    ),
    outputs=gr.JSON(label="Interpreted Strategy"),
    title="Strategy Interpreter",
    description="Convert trading strategy descriptions into structured JSON format.",
)

# Market types come straight from the schema enum so the UI can only submit
# values that a generated strategy is able to validate against.
_MARKET_TYPE_CHOICES = list(schema["properties"]["market_type"]["enum"])

# Tab 2: risk level + market type -> model-generated strategy.
iface_suggest = gr.Interface(
    fn=suggest_strategy,
    inputs=[
        gr.Textbox(lines=1, placeholder="Enter risk level (e.g., medium)...", label="Risk Level"),
        gr.Dropdown(
            choices=_MARKET_TYPE_CHOICES,
            value=_MARKET_TYPE_CHOICES[0],
            label="Market Type",
        ),
    ],
    outputs=gr.JSON(label="Suggested Strategy"),
    title="Strategy Suggester",
    description="Generate a unique trading strategy based on risk level and market type.",
)

# Combine both interfaces into one tabbed application.
app = gr.TabbedInterface(
    [iface_interpret, iface_suggest],
    ["Interpret Strategy", "Suggest Strategy"],
)

if __name__ == "__main__":
    app.launch()