Yhhxhfh committed on
Commit
db2e73b
·
verified ·
1 Parent(s): 1f6cebc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -8
app.py CHANGED
@@ -10,7 +10,6 @@ from functools import lru_cache
10
  from dotenv import load_dotenv
11
  from fastapi import FastAPI, Request, HTTPException
12
  from fastapi.responses import JSONResponse
13
- import time
14
  from tqdm import tqdm
15
 
16
  urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@@ -49,10 +48,11 @@ class ModelManager:
49
 
50
  def load_all_models(self):
51
  with ThreadPoolExecutor(max_workers=len(global_data['model_configs'])) as executor:
52
- futures = [executor.submit(self._load_model, config) for config in tqdm(global_data['model_configs'], desc="Cargando modelos")]
53
- for future in as_completed(futures):
54
  future.result()
55
 
 
56
  def _load_model(self, model_config):
57
  model_name = model_config['name']
58
  cache_file = os.path.join(self.model_cache_dir, f"{model_name}.pkl")
@@ -66,7 +66,7 @@ class ModelManager:
66
  with open(cache_file, "wb") as f:
67
  pickle.dump(self.models[model_name], f)
68
  except Exception as e:
69
- print(f"Error al cargar el modelo {model_name}: {e}")
70
  self.models[model_name] = None
71
 
72
  def get_model(self, model_name):
@@ -81,8 +81,8 @@ def normalize_input(input_text):
81
  return input_text.strip()
82
 
83
  def remove_duplicates(text):
84
- text = re.sub(r'(Hello there, how are you\? \[/INST\]){2,}', 'Hello there, how are you? [/INST]', text)
85
- text = re.sub(r'(How are you\? \[/INST\]){2,}', 'How are you? [/INST]', text)
86
  text = text.replace('[/INST]', '')
87
  lines = text.split('\n')
88
  unique_lines = []
@@ -107,7 +107,7 @@ async def process_message(message):
107
 
108
  with ThreadPoolExecutor(max_workers=len(global_data['model_configs'])) as executor:
109
  futures = [executor.submit(generate_model_response, model_manager.get_model(config['name']), inputs) for config in global_data['model_configs'] if model_manager.get_model(config['name'])]
110
- for i, future in enumerate(as_completed(futures)):
111
  model_name = global_data['model_configs'][i]['name']
112
  responses[model_name] = future.result()
113
 
@@ -122,7 +122,7 @@ async def api_generate_multimodel(request: Request):
122
  data = await request.json()
123
  message = data.get("message")
124
  if not message:
125
- raise HTTPException(status_code=400, detail="Mensaje faltante")
126
  response = await process_message(message)
127
  return JSONResponse({"response": response})
128
  except HTTPException as e:
 
10
  from dotenv import load_dotenv
11
  from fastapi import FastAPI, Request, HTTPException
12
  from fastapi.responses import JSONResponse
 
13
  from tqdm import tqdm
14
 
15
  urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
48
 
49
  def load_all_models(self):
50
  with ThreadPoolExecutor(max_workers=len(global_data['model_configs'])) as executor:
51
+ futures = [executor.submit(self._load_model, config) for config in tqdm(global_data['model_configs'], desc="Loading models")]
52
+ for future in tqdm(as_completed(futures), total=len(global_data['model_configs']), desc="Loading models complete"):
53
  future.result()
54
 
55
+
56
  def _load_model(self, model_config):
57
  model_name = model_config['name']
58
  cache_file = os.path.join(self.model_cache_dir, f"{model_name}.pkl")
 
66
  with open(cache_file, "wb") as f:
67
  pickle.dump(self.models[model_name], f)
68
  except Exception as e:
69
+ print(f"Error loading model {model_name}: {e}")
70
  self.models[model_name] = None
71
 
72
  def get_model(self, model_name):
 
81
  return input_text.strip()
82
 
83
  def remove_duplicates(text):
84
+ text = re.sub(r'(Hello there, how are you\? \[/INST\]){2,}', 'Hello there, how are you?', text)
85
+ text = re.sub(r'(How are you\? \[/INST\]){2,}', 'How are you?', text)
86
  text = text.replace('[/INST]', '')
87
  lines = text.split('\n')
88
  unique_lines = []
 
107
 
108
  with ThreadPoolExecutor(max_workers=len(global_data['model_configs'])) as executor:
109
  futures = [executor.submit(generate_model_response, model_manager.get_model(config['name']), inputs) for config in global_data['model_configs'] if model_manager.get_model(config['name'])]
110
+ for i, future in enumerate(tqdm(as_completed(futures), total=len([f for f in futures]), desc="Generating responses")):
111
  model_name = global_data['model_configs'][i]['name']
112
  responses[model_name] = future.result()
113
 
 
122
  data = await request.json()
123
  message = data.get("message")
124
  if not message:
125
+ raise HTTPException(status_code=400, detail="Missing message")
126
  response = await process_message(message)
127
  return JSONResponse({"response": response})
128
  except HTTPException as e: