sergiopaniego (HF Staff) committed
Commit 2cb48ed · 1 Parent(s): 66c19f0
Files changed (1): app.py (+8 -0)
app.py CHANGED
@@ -3,6 +3,7 @@ import subprocess
 import torch
 import gradio as gr
 from huggingface_hub import InferenceClient
+from vllm.config import DeviceConfig
 from vllm import LLM
 from sal.models.reward_models import RLHFFlow
 
@@ -17,6 +18,12 @@ if not os.path.exists("search-and-learn"):
     subprocess.run(["pip", "install", "-e", "./search-and-learn[dev]"])
 
 
+device_config = DeviceConfig(device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
+print('device_config', device_config)
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+print('device', device)
+
+
 model_path = "meta-llama/Llama-3.2-1B-Instruct"
 prm_path = "RLHFlow/Llama3.1-8B-PRM-Deepseek-Data"
 
@@ -25,6 +32,7 @@ llm = LLM(
     gpu_memory_utilization=0.5, # Utilize 50% of GPU memory
     enable_prefix_caching=True, # Optimize repeated prefix computations
     seed=42, # Set seed for reproducibility
+    config=device_config
 )
 
 
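For reference, a self-contained sketch of the device-selection pattern this commit introduces. It assumes a vLLM version whose engine arguments accept a device option; whether vllm.LLM accepts a config= keyword is version-dependent, so the sketch passes the resolved device string directly, which is an assumption to verify against the pinned vLLM release.

import torch
from vllm import LLM

# Resolve the target device once: prefer CUDA, fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("device", device)

# Assumption: this vLLM version forwards device= through to its engine
# arguments; if not, the DeviceConfig approach from the commit applies.
llm = LLM(
    model="meta-llama/Llama-3.2-1B-Instruct",
    gpu_memory_utilization=0.5,  # Utilize 50% of GPU memory
    enable_prefix_caching=True,  # Optimize repeated prefix computations
    seed=42,  # Set seed for reproducibility
    device=device,
)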