MusIre commited on
Commit
8d2c3ff
·
verified ·
1 Parent(s): abc4ab8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -0
app.py CHANGED
@@ -85,8 +85,17 @@ model_name = "EleutherAI/gpt-neo-1.3B"
85
  tokenizer = AutoTokenizer.from_pretrained(model_name)
86
  model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
87
 
 
 
 
 
 
 
 
 
88
  # Function to enrich prompt
89
  def enrich_prompt(artist, style):
 
90
  artist_info = dataset_desc.loc[dataset_desc['artists'].str.lower() == artist.lower(), 'description'].values
91
  style_info = style_desc.loc[style_desc['style'].str.lower() == style.lower(), 'description'].values
92
 
 
85
  tokenizer = AutoTokenizer.from_pretrained(model_name)
86
  model_gptneo = AutoModelForCausalLM.from_pretrained(model_name).to(device)
87
 
88
+
89
+ #Load dataset
90
+
91
+ dataset_desc = pd.read_csv("dataset_desc.csv", delimiter=';', usecols=['Artists', 'Style', 'Description'])
92
+ dataset_desc.columns = dataset_desc.columns.str.lower()
93
+ style_desc = pd.read_csv("style_desc.csv", delimiter=';')
94
+ style_desc.columns = style_desc.columns.str.lower()
95
+
96
  # Function to enrich prompt
97
  def enrich_prompt(artist, style):
98
+
99
  artist_info = dataset_desc.loc[dataset_desc['artists'].str.lower() == artist.lower(), 'description'].values
100
  style_info = style_desc.loc[style_desc['style'].str.lower() == style.lower(), 'description'].values
101