Keltezaa commited on
Commit
e043aa2
·
verified ·
1 Parent(s): 41e3236

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -0
app.py CHANGED
@@ -14,6 +14,7 @@ import random
14
  import time
15
  import requests
16
  import pandas as pd
 
17
 
18
  # Disable tokenizer parallelism
19
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -56,6 +57,9 @@ pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
56
  MAX_SEED = 2**32 - 1
57
 
58
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
 
 
 
59
 
60
  def process_input(input_text):
61
  # Tokenize and truncate input
@@ -69,6 +73,17 @@ def process_input(input_text):
69
  input_text = "Your long prompt goes here..."
70
  inputs = process_input(input_text)
71
 
 
 
 
 
 
 
 
 
 
 
 
72
  class calculateDuration:
73
  def __init__(self, activity_name=""):
74
  self.activity_name = activity_name
 
14
  import time
15
  import requests
16
  import pandas as pd
17
+ import torch.nn as nn
18
 
19
  # Disable tokenizer parallelism
20
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
57
  MAX_SEED = 2**32 - 1
58
 
59
  pipe.flux_pipe_call_that_returns_an_iterable_of_images = flux_pipe_call_that_returns_an_iterable_of_images.__get__(pipe)
60
+ # Replace text embedding generation with pre-computed Longformer embeddings
61
+ pipeline.text_encoder = None # Disable the default CLIP text encoder
62
+ pipeline.longformer_embeddings = transformed_embeddings # Inject your embeddings
63
 
64
  def process_input(input_text):
65
  # Tokenize and truncate input
 
73
  input_text = "Your long prompt goes here..."
74
  inputs = process_input(input_text)
75
 
76
+ # Get Longformer embeddings
77
+ with torch.no_grad():
78
+ longformer_embeddings = longformer_model(**encoded_input).last_hidden_state
79
+
80
+ # Create a transformation layer to match CLIP's embedding dimension
81
+ transform_layer = nn.Linear(pooled_embeddings.size(-1), 512)
82
+ transformed_embeddings = transform_layer(pooled_embeddings) # Shape: [batch_size, 512]
83
+
84
+ # Pass your embeddings to the pipeline during generation
85
+ image = pipeline(prompt=None, text_embeddings=pipeline.longformer_embeddings)
86
+
87
  class calculateDuration:
88
  def __init__(self, activity_name=""):
89
  self.activity_name = activity_name