Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,7 @@ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
|
|
8 |
st.set_page_config(page_title="AI Study Assistant", page_icon="🤖", layout="wide")
|
9 |
|
10 |
# Set up the Groq API Key
|
11 |
-
GROQ_API_KEY = "
|
12 |
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
|
13 |
|
14 |
# Initialize the Groq client
|
@@ -25,11 +25,13 @@ MODEL_NAME = "deepseek-ai/DeepSeek-R1"
|
|
25 |
|
26 |
try:
|
27 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
|
|
28 |
model = AutoModelForCausalLM.from_pretrained(
|
29 |
MODEL_NAME,
|
30 |
trust_remote_code=True,
|
31 |
-
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
32 |
-
device_map="auto" if torch.cuda.is_available() else None
|
|
|
33 |
)
|
34 |
|
35 |
def generate_response_hf(user_message):
|
|
|
8 |
st.set_page_config(page_title="AI Study Assistant", page_icon="🤖", layout="wide")
|
9 |
|
10 |
# Set up the Groq API Key
|
11 |
+
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")  # SECURITY: the hard-coded key previously committed here is publicly leaked by this diff and must be revoked immediately; never commit secrets — load them from the environment (e.g. a .env file or deployment secret store)
|
12 |
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
|
13 |
|
14 |
# Initialize the Groq client
|
|
|
25 |
|
26 |
try:
|
27 |
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
28 |
+
|
29 |
model = AutoModelForCausalLM.from_pretrained(
|
30 |
MODEL_NAME,
|
31 |
trust_remote_code=True,
|
32 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, # ✅ Use FP16 on GPU, FP32 on CPU
|
33 |
+
device_map="auto" if torch.cuda.is_available() else None, # ✅ Enable auto GPU usage
|
34 |
+
quantization_config=None  # NOTE(review): None is already the default and does NOT actively disable quantization; if the checkpoint's config.json declares FP8 quantization, this argument will not override it — confirm against the model's config and strip/override it explicitly if needed
|
35 |
)
|
36 |
|
37 |
def generate_response_hf(user_message):
|