soufyane commited on
Commit
d749365
1 Parent(s): 08517cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -0
app.py CHANGED
@@ -4,6 +4,11 @@ import keras_nlp
4
 
5
  import numpy as np
6
  import pandas as pd
 
 
 
 
 
7
  keras.utils.set_random_seed(42)
8
 
9
  gemma_lm = keras_nlp.models.CausalLM.from_preset("hf://soufyane/gemma_2b_instruct_FT_DATA_SCIENCE_lora36_1")
 
4
 
5
  import numpy as np
6
  import pandas as pd
7
+
8
+ import os
9
+ os.environ["KERAS_BACKEND"] = "jax" # you can also use tensorflow or torch
10
+ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1" # avoid memory fragmentation on JAX backend.
11
+
12
  keras.utils.set_random_seed(42)
13
 
14
  gemma_lm = keras_nlp.models.CausalLM.from_preset("hf://soufyane/gemma_2b_instruct_FT_DATA_SCIENCE_lora36_1")