KeerthiVM commited on
Commit
77a8dca
·
1 Parent(s): 2caff32
Files changed (2) hide show
  1. SkinGPT.py +20 -21
  2. app.py +1 -0
SkinGPT.py CHANGED
@@ -310,27 +310,26 @@ class SkinGPT4(nn.Module):
310
  print(f"\n[DEBUG] After replacement:")
311
  print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[0][1], :5]}...")
312
 
313
- # outputs = self.llama.generate(
314
- # inputs_embeds=input_embeddings,
315
- # max_new_tokens=max_new_tokens,
316
- # temperature=0.7,
317
- # top_k=40,
318
- # top_p=0.9,
319
- # repetition_penalty=1.1,
320
- # do_sample=True,
321
- # pad_token_id = self.tokenizer.eos_token_id,
322
- # eos_token_id = self.tokenizer.eos_token_id
323
- # )
324
- #
325
- #
326
- # full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
327
- # # print(f"Full Output from llama : {full_output}")
328
- # response = full_output.split("### Response:")[-1].strip()
329
- # # print(f"Response from llama : {full_output}")
330
- #
331
- # return response
332
-
333
- return None
334
 
335
  class SkinGPTClassifier:
336
  def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
 
310
  print(f"\n[DEBUG] After replacement:")
311
  print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[0][1], :5]}...")
312
 
313
+ outputs = self.llama.generate(
314
+ inputs_embeds=input_embeddings,
315
+ max_new_tokens=max_new_tokens,
316
+ temperature=0.7,
317
+ top_k=40,
318
+ top_p=0.9,
319
+ repetition_penalty=1.1,
320
+ do_sample=True,
321
+ pad_token_id = self.tokenizer.eos_token_id,
322
+ eos_token_id = self.tokenizer.eos_token_id
323
+ )
324
+
325
+
326
+ full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
327
+ # print(f"Full Output from llama : {full_output}")
328
+ response = full_output.split("### Response:")[-1].strip()
329
+ # print(f"Response from llama : {full_output}")
330
+
331
+ return response
332
+
 
333
 
334
  class SkinGPTClassifier:
335
  def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
app.py CHANGED
@@ -49,6 +49,7 @@ warnings.filterwarnings("ignore")
49
 
50
  @st.cache_resource
51
  def get_classifier():
 
52
  torch.use_deterministic_algorithms(True)
53
  classifier = SkinGPTClassifier()
54
  for module in [classifier.model.vit,
 
49
 
50
  @st.cache_resource
51
  def get_classifier():
52
+ st.cache_resource.clear()
53
  torch.use_deterministic_algorithms(True)
54
  classifier = SkinGPTClassifier()
55
  for module in [classifier.model.vit,