Keltezaa committed on
Commit f36149d · verified · 1 Parent(s): f27fc80

Update app.py

Files changed (1):
  1. app.py +38 -38
app.py CHANGED
@@ -19,39 +19,6 @@ import pandas as pd
 # Disable tokenizer parallelism
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
-# Initialize the CLIP tokenizer and model
-clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")
-clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
-clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
-
-# Initialize the Longformer tokenizer and model
-longformer_tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
-longformer_model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
-
-# Example usage
-input_text = "Your long prompt goes here..."
-inputs = preprocess_prompt(input_text)
-
-def preprocess_prompt(input_text, max_clip_tokens=77):
-    """
-    Preprocess the input prompt based on its length:
-    - If the prompt is <= max_clip_tokens, summarize it.
-    - If the prompt is > max_clip_tokens, split and process it.
-    """
-    # Tokenize the prompt to determine its token length
-    tokens = clip_processor.tokenizer(input_text, return_tensors="pt")["input_ids"][0]
-    token_count = len(tokens)
-
-    if token_count <= max_clip_tokens:
-        # Use summarization for shorter prompts
-        print("Using summarization (Option 5) as the prompt is short.")
-        return process_summarized_input(input_text)
-    else:
-        # Use split-and-process for longer prompts
-        print("Using chunking (Option 3) as the prompt exceeds 77 tokens.")
-        return process_clip_chunks(input_text)
-
-
 # Summarization Function (Option 5)
 def summarize_prompt(input_text, max_length=77):
     """
@@ -62,7 +29,6 @@ def summarize_prompt(input_text, max_length=77):
     print(f"Summarized prompt: {summarized_text}")
     return summarized_text
 
-
 def process_summarized_input(input_text):
     """
     Prepares summarized text for CLIP processing.
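
The body of summarize_prompt is elided by this hunk, so the summarization mechanism itself is not visible here. A sketch of one plausible implementation, assuming the transformers summarization pipeline (the model choice and the _sketch suffix are hypothetical, not from this repo):

    from transformers import pipeline

    # Assumed summarizer; any seq2seq summarization checkpoint would do.
    summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-12-6")

    def summarize_prompt_sketch(input_text, max_length=77):
        # Note: max_length here counts the summarizer's tokens, not CLIP's.
        result = summarizer(input_text, max_length=max_length, min_length=5, do_sample=False)
        summarized_text = result[0]["summary_text"]
        print(f"Summarized prompt: {summarized_text}")
        return summarized_text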
@@ -71,7 +37,6 @@ def process_summarized_input(input_text):
     inputs = clip_processor(text=summarized_text, return_tensors="pt", padding=True, truncation=True, max_length=77)
     return inputs
 
-
 def split_prompt_with_overlap(prompt, chunk_size=77, overlap=10):
     tokens = clip_processor.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
     chunks = [
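
For orientation: the dict returned by clip_processor unpacks directly into the CLIP text encoder, which is presumably where these inputs end up. A usage sketch (the downstream call is an assumption; it does not appear in this diff):

    # Processor output -> CLIP text embedding.
    inputs = clip_processor(text="a scenic mountain lake", return_tensors="pt",
                            padding=True, truncation=True, max_length=77)
    text_features = clip_model.get_text_features(**inputs)  # shape (1, 512) for ViT-B/16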
@@ -79,9 +44,12 @@ def split_prompt_with_overlap(prompt, chunk_size=77, overlap=10):
         for i in range(0, len(tokens), chunk_size - overlap)
     ]
     return chunks
-
-chunks = split_prompt("Test " * 200)
-assert all(len(chunk) <= 77 for chunk in chunks), "Chunk size exceeded"
+
+def split_prompt(prompt, chunk_size=77):
+    """Splits a long prompt into chunks of the specified token size."""
+    tokens = clip_processor.tokenizer(prompt, return_tensors="pt")["input_ids"][0]
+    chunks = [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)]
+    return chunks
 
 def process_clip_chunks(input_text):
     """
@@ -96,6 +64,38 @@ def process_clip_chunks(input_text):
         processed_chunks.append(inputs)
     return processed_chunks  # Return processed chunks for downstream usage
 
+def preprocess_prompt(input_text, max_clip_tokens=77):
+    """
+    Preprocess the input prompt based on its length:
+    - If the prompt is <= max_clip_tokens, summarize it.
+    - If the prompt is > max_clip_tokens, split and process it.
+    """
+    # Tokenize the prompt to determine its token length
+    tokens = clip_processor.tokenizer(input_text, return_tensors="pt")["input_ids"][0]
+    token_count = len(tokens)
+
+    if token_count <= max_clip_tokens:
+        # Use summarization for shorter prompts
+        print("Using summarization (Option 5) as the prompt is short.")
+        return process_summarized_input(input_text)
+    else:
+        # Use split-and-process for longer prompts
+        print("Using chunking (Option 3) as the prompt exceeds 77 tokens.")
+        return process_clip_chunks(input_text)
+
+# Initialize the CLIP tokenizer and model
+clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch16")
+clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
+
+# Initialize the Longformer tokenizer and model
+longformer_tokenizer = LongformerTokenizer.from_pretrained("allenai/longformer-base-4096")
+longformer_model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
+
+# Example usage
+input_text = "Your long prompt goes here..."
+inputs = preprocess_prompt(input_text)
+
 # Load prompts for randomization
 df = pd.read_csv('prompts.csv', header=None)
 prompt_values = df.values.flatten()
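
process_clip_chunks returns a list of per-chunk processor inputs, and the diff does not show how they are consumed. One way to reduce them to a single prompt embedding, assuming mean pooling over chunk embeddings (the pooling choice and the helper name are assumptions):

    import torch

    def embed_long_prompt(input_text):
        # Encode each <=77-token chunk separately, then average the embeddings.
        chunk_inputs = process_clip_chunks(input_text)
        with torch.no_grad():
            feats = [clip_model.get_text_features(**inputs) for inputs in chunk_inputs]
        return torch.cat(feats, dim=0).mean(dim=0)  # shape (512,)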
 
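One loose end: the Longformer tokenizer and model are initialized in the moved block but never referenced in the visible hunks, so they presumably back a long-prompt path elsewhere in app.py. For reference, typical encoding usage (not from this commit):

    enc = longformer_tokenizer(input_text, return_tensors="pt", truncation=True, max_length=4096)
    hidden = longformer_model(**enc).last_hidden_state  # shape (1, seq_len, 768) for longformer-base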