Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -43,14 +43,16 @@ model_name = "unsloth/Llama-3.2-3B-Instruct"
|
|
| 43 |
# )
|
| 44 |
|
| 45 |
model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
|
|
|
|
|
|
|
| 46 |
model = transformers.AutoModelForCausalLM.from_pretrained(
|
| 47 |
model_name,
|
| 48 |
trust_remote_code=True,
|
| 49 |
config=model_config,
|
| 50 |
-
|
| 51 |
-
device_map=device,
|
| 52 |
)
|
| 53 |
|
|
|
|
| 54 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 55 |
|
| 56 |
query_pipeline = transformers.pipeline(
|
|
@@ -58,14 +60,15 @@ query_pipeline = transformers.pipeline(
|
|
| 58 |
model=model,
|
| 59 |
tokenizer=tokenizer,
|
| 60 |
return_full_text=True,
|
| 61 |
-
torch_dtype=torch.float16,
|
| 62 |
-
device_map=device,
|
| 63 |
temperature=0.7,
|
| 64 |
top_p=0.9,
|
| 65 |
top_k=50,
|
| 66 |
-
max_new_tokens=256
|
| 67 |
)
|
| 68 |
|
|
|
|
| 69 |
llm = HuggingFacePipeline(pipeline=query_pipeline)
|
| 70 |
|
| 71 |
books_db_client_retriever = RetrievalQA.from_chain_type(
|
|
|
|
| 43 |
# )
|
| 44 |
|
| 45 |
model_config = transformers.AutoConfig.from_pretrained(model_name, max_new_tokens=1024)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
model = transformers.AutoModelForCausalLM.from_pretrained(
|
| 49 |
model_name,
|
| 50 |
trust_remote_code=True,
|
| 51 |
config=model_config,
|
| 52 |
+
device_map="auto" if device == "cuda" else None,
|
|
|
|
| 53 |
)
|
| 54 |
|
| 55 |
+
|
| 56 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 57 |
|
| 58 |
query_pipeline = transformers.pipeline(
|
|
|
|
| 60 |
model=model,
|
| 61 |
tokenizer=tokenizer,
|
| 62 |
return_full_text=True,
|
| 63 |
+
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
|
| 64 |
+
device_map="auto" if device == "cuda" else None,
|
| 65 |
temperature=0.7,
|
| 66 |
top_p=0.9,
|
| 67 |
top_k=50,
|
| 68 |
+
max_new_tokens=128 # Reduce this from 256
|
| 69 |
)
|
| 70 |
|
| 71 |
+
|
| 72 |
llm = HuggingFacePipeline(pipeline=query_pipeline)
|
| 73 |
|
| 74 |
books_db_client_retriever = RetrievalQA.from_chain_type(
|