fix added
Browse files- SkinGPT.py +20 -21
- app.py +1 -0
SkinGPT.py
CHANGED
@@ -310,27 +310,26 @@ class SkinGPT4(nn.Module):
|
|
310 |
print(f"\n[DEBUG] After replacement:")
|
311 |
print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[0][1], :5]}...")
|
312 |
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
#
|
328 |
-
|
329 |
-
#
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
return None
|
334 |
|
335 |
class SkinGPTClassifier:
|
336 |
def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
|
|
|
310 |
print(f"\n[DEBUG] After replacement:")
|
311 |
print(f"Image token embedding (after):\n{input_embeddings[0, replace_positions[0][1], :5]}...")
|
312 |
|
313 |
+
outputs = self.llama.generate(
|
314 |
+
inputs_embeds=input_embeddings,
|
315 |
+
max_new_tokens=max_new_tokens,
|
316 |
+
temperature=0.7,
|
317 |
+
top_k=40,
|
318 |
+
top_p=0.9,
|
319 |
+
repetition_penalty=1.1,
|
320 |
+
do_sample=True,
|
321 |
+
pad_token_id = self.tokenizer.eos_token_id,
|
322 |
+
eos_token_id = self.tokenizer.eos_token_id
|
323 |
+
)
|
324 |
+
|
325 |
+
|
326 |
+
full_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
327 |
+
# print(f"Full Output from llama : {full_output}")
|
328 |
+
response = full_output.split("### Response:")[-1].strip()
|
329 |
+
# print(f"Response from llama : {full_output}")
|
330 |
+
|
331 |
+
return response
|
332 |
+
|
|
|
333 |
|
334 |
class SkinGPTClassifier:
|
335 |
def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
|
app.py
CHANGED
@@ -49,6 +49,7 @@ warnings.filterwarnings("ignore")
|
|
49 |
|
50 |
@st.cache_resource
|
51 |
def get_classifier():
|
|
|
52 |
torch.use_deterministic_algorithms(True)
|
53 |
classifier = SkinGPTClassifier()
|
54 |
for module in [classifier.model.vit,
|
|
|
49 |
|
50 |
@st.cache_resource
|
51 |
def get_classifier():
|
52 |
+
st.cache_resource.clear()
|
53 |
torch.use_deterministic_algorithms(True)
|
54 |
classifier = SkinGPTClassifier()
|
55 |
for module in [classifier.model.vit,
|