# app.py
"""Gradio Space that converts crypto trading strategies to/from structured JSON.

Two tabs are exposed:

* "Interpret Strategy" turns a free-text strategy description into JSON.
* "Suggest Strategy" asks the model to invent a strategy for a given risk
  level and market type.

Both tabs prompt a causal language model with ``schema`` (defined below),
extract the first JSON object from the model's completion, and validate it
with ``jsonschema``. Failures are reported to the UI as
``{"error": "<message>"}`` dicts rather than raised.
"""

import json
import logging

import gradio as gr
import torch
from jsonschema import ValidationError, validate
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("StrategyInterpreterSpace")

# Upper bound on the number of tokens generated per request. This is passed to
# ``generate`` as ``max_new_tokens`` rather than ``max_length``: ``max_length``
# also counts the prompt tokens, and the prompts embed the whole schema, so a
# small ``max_length`` can leave little or no room for the answer.
MAX_NEW_TOKENS = 512
TEMPERATURE = 0.7
TOP_P = 0.9

# Load model and tokenizer
model_name = "EleutherAI/gpt-neo-2.7B"  # Using a smaller model to fit within memory constraints
logger.info("Loading model '%s'...", model_name)
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    logger.info("Model loaded successfully.")
except Exception:
    logger.exception("Failed to load model '%s'.", model_name)
    raise  # bare raise preserves the original traceback

# Define JSON schema
schema = {
    "type": "object",
    "required": [
        "strategy_name",
        "market_type",
        "assets",
        "trade_parameters",
        "conditions",
        "risk_management",
    ],
    "properties": {
        "strategy_name": {"type": "string"},
        "market_type": {"type": "string", "enum": ["spot", "futures", "margin"]},
        "assets": {"type": "array", "items": {"type": "string"}},
        "trade_parameters": {
            "type": "object",
            "required": ["leverage", "order_type", "position_size"],
            "properties": {
                "leverage": {"type": "number"},
                "order_type": {"type": "string"},
                "position_size": {"type": "number"},
            },
        },
        "conditions": {
            "type": "object",
            "required": ["entry", "exit"],
            "properties": {
                "entry": {
                    "type": "array",
                    "items": {"$ref": "#/definitions/condition"},
                },
                "exit": {
                    "type": "array",
                    "items": {"$ref": "#/definitions/condition"},
                },
            },
        },
        "risk_management": {
            "type": "object",
            "required": ["stop_loss", "take_profit", "trailing_stop_loss"],
            "properties": {
                "stop_loss": {"type": "number"},
                "take_profit": {"type": "number"},
                "trailing_stop_loss": {"type": "number"},
            },
        },
    },
    "definitions": {
        "condition": {
            "type": "object",
            "required": ["indicator", "operator", "value", "timeframe"],
            "properties": {
                "indicator": {"type": "string"},
                "operator": {"type": "string", "enum": [">", "<", "==", ">=", "<="]},
                "value": {"type": ["string", "number"]},
                "timeframe": {"type": "string"},
                "indicator_parameters": {
                    "type": "object",
                    "properties": {
                        "period": {"type": "number"},
                    },
                    "additionalProperties": True,
                },
            },
        }
    },
}

# The schema is constant, so serialize it once instead of on every request.
SCHEMA_TEXT = json.dumps(schema, indent=2)
# Allowed values for the "market_type" field, taken from the schema enum.
MARKET_TYPES = schema["properties"]["market_type"]["enum"]


class _StageError(Exception):
    """Signals a failed tokenize/generate/decode step; str() is the UI message."""


def _generate_completion(prompt: str) -> str:
    """Run the model on ``prompt`` and return only the newly generated text.

    The prompt is removed at the token level (``outputs[0][prompt_len:]``).
    Slicing the decoded string by ``len(prompt)`` is not reliable, because
    decoding does not always reproduce the prompt character-for-character.

    Raises:
        _StageError: if tokenization, generation or decoding fails, or if the
            prompt already fills the model's context window.
    """
    try:
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"].to(model.device)
        attention_mask = encoded["attention_mask"].to(model.device)
        logger.info("Tokenized the input prompt (%d tokens).", input_ids.shape[-1])
    except Exception as e:
        logger.exception("Error during tokenization.")
        raise _StageError(f"Error during tokenization: {e}") from e

    prompt_len = input_ids.shape[-1]
    budget = MAX_NEW_TOKENS
    # Never request more positions than the model can address.
    context_len = getattr(model.config, "max_position_embeddings", None)
    if context_len is not None:
        budget = min(budget, context_len - prompt_len)
        if budget <= 0:
            msg = (
                f"Error during model generation: the prompt is {prompt_len} tokens, "
                f"which leaves no room in the model's {context_len}-token context."
            )
            logger.error(msg)
            raise _StageError(msg)

    # Reuse EOS as padding when the tokenizer defines no pad token.
    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    try:
        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=budget,
                temperature=TEMPERATURE,
                top_p=TOP_P,
                do_sample=True,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=pad_token_id,
            )
        logger.info("Model generated output.")
    except Exception as e:
        logger.exception("Error during model generation.")
        raise _StageError(f"Error during model generation: {e}") from e

    try:
        new_tokens = outputs[0][prompt_len:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        logger.debug("Response text: %s", response_text)
    except Exception as e:
        logger.exception("Error during decoding.")
        raise _StageError(f"Error during decoding: {e}") from e
    return response_text


def _parse_and_validate(response_text: str) -> dict:
    """Decode the first JSON object in ``response_text`` and validate it.

    Text before the first ``{`` and after the matching close of that object is
    ignored. The model often keeps generating after the JSON ends, and a
    plain ``json.loads`` would reject that with "Extra data".

    Raises:
        json.JSONDecodeError: no well-formed JSON object could be decoded.
        ValidationError: the decoded value does not satisfy ``schema``.
    """
    start = response_text.find("{")
    if start == -1:
        raise json.JSONDecodeError("No JSON object found in model output", response_text, 0)
    strategy_data, _end = json.JSONDecoder().raw_decode(response_text[start:])
    validate(instance=strategy_data, schema=schema)
    return strategy_data


def interpret_strategy(description: str) -> dict:
    """Convert a free-text trading strategy description into schema-valid JSON.

    Args:
        description: The user's natural-language strategy description.

    Returns:
        The validated strategy dict, or ``{"error": <message>}`` on failure
        (the dict is rendered directly by Gradio's JSON output).
    """
    logger.info("Received strategy description for interpretation.")
    prompt = f"""
You are an expert crypto trading assistant. Convert the following trading strategy description into a JSON format strictly following this schema:
{SCHEMA_TEXT}
Ensure that the response contains only valid JSON with the correct parameters. Do not include any additional text or explanations.
Strategy Description:
{description}
JSON:
"""
    logger.debug("Prompt constructed: %s", prompt)

    try:
        response_text = _generate_completion(prompt)
    except _StageError as e:
        return {"error": str(e)}

    try:
        strategy_data = _parse_and_validate(response_text)
    except json.JSONDecodeError as e:
        logger.error("JSON decoding error: %s", e)
        # Return raw text for debugging
        return {"error": f"Error interpreting strategy: Invalid JSON format.\nGenerated Text:\n{response_text}\nDetails: {e}"}
    except ValidationError as e:
        logger.error("JSON validation error: %s", e)
        return {"error": f"Error interpreting strategy: JSON does not conform to schema.\nDetails: {e.message}"}

    logger.info("Strategy interpreted successfully and validated against schema.")
    return strategy_data


def suggest_strategy(risk_level: str, market_type: str) -> dict:
    """Ask the model to invent a strategy for a risk level and market type.

    Args:
        risk_level: Free-text risk appetite (e.g. "medium").
        market_type: One of the schema's market types ("spot", "futures",
            "margin"). Surrounding whitespace and letter case are ignored.

    Returns:
        The validated strategy dict, or ``{"error": <message>}`` on failure.
        A generated strategy whose ``market_type`` differs from the requested
        one is also reported as an error.
    """
    logger.info("Received request to suggest a new strategy.")
    # Normalize and check before running the expensive generation step.
    market_type = (market_type or "").strip().lower()
    if market_type not in MARKET_TYPES:
        msg = f"Error generating strategy: market type must be one of: {', '.join(MARKET_TYPES)}."
        logger.error(msg)
        return {"error": msg}
    risk_level = (risk_level or "").strip()

    prompt = f"""Please create a unique crypto trading strategy suitable for a '{risk_level}' risk appetite in the '{market_type}' market.
Ensure the JSON matches this schema:
{SCHEMA_TEXT}
Use indicators and conditions that can be applied by ccxt, bitget, pandas-ta, and backtrader.
JSON:"""
    logger.debug("Prompt constructed for strategy suggestion: %s", prompt)

    try:
        response_text = _generate_completion(prompt)
    except _StageError as e:
        return {"error": str(e)}

    try:
        strategy_data = _parse_and_validate(response_text)
    except json.JSONDecodeError as e:
        logger.error("JSON decoding error: %s", e)
        # Return raw text for debugging
        return {"error": f"Error generating strategy: Invalid JSON format.\nGenerated Text:\n{response_text}\nDetails: {e}"}
    except ValidationError as e:
        logger.error("JSON validation error: %s", e)
        return {"error": f"Error generating strategy: JSON does not conform to schema.\nDetails: {e.message}"}

    if strategy_data.get("market_type") != market_type:
        msg = "The generated strategy's market type does not match the selected market type."
        logger.error("Market type mismatch error: %s", msg)
        return {"error": f"Error generating strategy: {msg}"}

    logger.info("Strategy suggested successfully and validated against schema.")
    return strategy_data


iface_interpret = gr.Interface(
    fn=interpret_strategy,
    inputs=gr.Textbox(lines=10, placeholder="Enter your strategy description here...", label="Strategy Description"),
    outputs=gr.JSON(label="Interpreted Strategy"),
    title="Strategy Interpreter",
    description="Convert trading strategy descriptions into structured JSON format.",
)

iface_suggest = gr.Interface(
    fn=suggest_strategy,
    inputs=[
        gr.Textbox(lines=1, placeholder="Enter risk level (e.g., medium)...", label="Risk Level"),
        gr.Textbox(lines=1, placeholder="Enter market type (spot, futures, or margin)...", label="Market Type"),
    ],
    outputs=gr.JSON(label="Suggested Strategy"),
    title="Strategy Suggester",
    description="Generate a unique trading strategy based on risk level and market type.",
)

app = gr.TabbedInterface([iface_interpret, iface_suggest], ["Interpret Strategy", "Suggest Strategy"])

if __name__ == "__main__":
    app.launch()