tmmdev commited on
Commit
21ac398
·
verified ·
1 Parent(s): 77a41a7

Update pattern_analyzer.py

Browse files
Files changed (1) hide show
  1. pattern_analyzer.py +17 -13
pattern_analyzer.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  os.environ['HF_HOME'] = '/tmp/huggingface'
 
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
  import torch
5
  import json
@@ -12,20 +13,28 @@ class PatternAnalyzer:
12
  "device_map": "auto",
13
  "torch_dtype": torch.float32,
14
  "low_cpu_mem_usage": True,
15
- "max_memory": {"cpu": "8GB"},
16
- "offload_folder": "/tmp/offload"
 
 
 
 
 
17
  }
18
-
19
  self.model = AutoModelForCausalLM.from_pretrained(
20
  "tmmdev/codellama-pattern-analysis",
21
- **model_kwargs
 
 
 
22
  )
23
-
24
  self.tokenizer = AutoTokenizer.from_pretrained(
25
  "tmmdev/codellama-pattern-analysis",
26
  use_fast=True
27
  )
28
-
29
  self.basic_patterns = {
30
  'channel': {'min_points': 4, 'confidence_threshold': 0.7},
31
  'triangle': {'min_points': 3, 'confidence_threshold': 0.75},
@@ -34,6 +43,7 @@ class PatternAnalyzer:
34
  'double_top': {'max_deviation': 0.02, 'confidence_threshold': 0.85},
35
  'double_bottom': {'max_deviation': 0.02, 'confidence_threshold': 0.85}
36
  }
 
37
  self.pattern_logic = PatternLogic()
38
 
39
  def analyze_data(self, ohlcv_data):
@@ -44,16 +54,13 @@ class PatternAnalyzer:
44
  2. Triangle: Must have clear convergence point
45
  3. Support: Minimum 3 price bounces
46
  4. Resistance: Minimum 3 price rejections
47
-
48
  INPUT DATA:
49
  {ohlcv_data.to_json(orient='records')}
50
-
51
  Return ONLY high-confidence patterns (>0.8) in JSON format with exact price coordinates."""
52
 
53
  inputs = self.tokenizer(data_prompt, return_tensors="pt")
54
  outputs = self.model.generate(**inputs, max_length=1000)
55
  analysis = self.tokenizer.decode(outputs[0])
56
-
57
  return self.parse_analysis(analysis)
58
 
59
  def parse_analysis(self, analysis_text):
@@ -61,10 +68,9 @@ class PatternAnalyzer:
61
  json_start = analysis_text.find('{')
62
  json_end = analysis_text.rfind('}') + 1
63
  json_str = analysis_text[json_start:json_end]
64
-
65
  analysis_data = json.loads(json_str)
66
  patterns = []
67
-
68
  for pattern in analysis_data.get('patterns', []):
69
  pattern_type = pattern.get('type')
70
  if pattern_type in self.basic_patterns:
@@ -79,8 +85,6 @@ class PatternAnalyzer:
79
  'timestamp': pd.Timestamp.now().isoformat()
80
  }
81
  })
82
-
83
  return patterns
84
-
85
  except json.JSONDecodeError:
86
  return []
 
1
  import os
2
  os.environ['HF_HOME'] = '/tmp/huggingface'
3
+
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import torch
6
  import json
 
13
  "device_map": "auto",
14
  "torch_dtype": torch.float32,
15
  "low_cpu_mem_usage": True,
16
+ "max_memory": {
17
+ "cpu": "4GB",
18
+ "disk": "8GB"
19
+ },
20
+ "offload_folder": "/tmp/offload",
21
+ "load_in_8bit": True,
22
+ "revision": "main"
23
  }
24
+
25
  self.model = AutoModelForCausalLM.from_pretrained(
26
  "tmmdev/codellama-pattern-analysis",
27
+ **model_kwargs,
28
+ use_safetensors=True,
29
+ trust_remote_code=True,
30
+ resume_download=True
31
  )
32
+
33
  self.tokenizer = AutoTokenizer.from_pretrained(
34
  "tmmdev/codellama-pattern-analysis",
35
  use_fast=True
36
  )
37
+
38
  self.basic_patterns = {
39
  'channel': {'min_points': 4, 'confidence_threshold': 0.7},
40
  'triangle': {'min_points': 3, 'confidence_threshold': 0.75},
 
43
  'double_top': {'max_deviation': 0.02, 'confidence_threshold': 0.85},
44
  'double_bottom': {'max_deviation': 0.02, 'confidence_threshold': 0.85}
45
  }
46
+
47
  self.pattern_logic = PatternLogic()
48
 
49
  def analyze_data(self, ohlcv_data):
 
54
  2. Triangle: Must have clear convergence point
55
  3. Support: Minimum 3 price bounces
56
  4. Resistance: Minimum 3 price rejections
 
57
  INPUT DATA:
58
  {ohlcv_data.to_json(orient='records')}
 
59
  Return ONLY high-confidence patterns (>0.8) in JSON format with exact price coordinates."""
60
 
61
  inputs = self.tokenizer(data_prompt, return_tensors="pt")
62
  outputs = self.model.generate(**inputs, max_length=1000)
63
  analysis = self.tokenizer.decode(outputs[0])
 
64
  return self.parse_analysis(analysis)
65
 
66
  def parse_analysis(self, analysis_text):
 
68
  json_start = analysis_text.find('{')
69
  json_end = analysis_text.rfind('}') + 1
70
  json_str = analysis_text[json_start:json_end]
 
71
  analysis_data = json.loads(json_str)
72
  patterns = []
73
+
74
  for pattern in analysis_data.get('patterns', []):
75
  pattern_type = pattern.get('type')
76
  if pattern_type in self.basic_patterns:
 
85
  'timestamp': pd.Timestamp.now().isoformat()
86
  }
87
  })
 
88
  return patterns
 
89
  except json.JSONDecodeError:
90
  return []