mohdelgaar commited on
Commit
7fb1029
·
1 Parent(s): 0860d85

upgrade gradio

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +18 -3
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🔁
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 4.40.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: indigo
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 5.9.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -149,13 +149,28 @@ def generate_random(sent1, count, approx):
149
  if sent1 == '':
150
  raise gr.Error('Please input a source text.')
151
  preds, interpolations = [], []
 
 
152
  for c in range(count):
153
  idx = np.random.randint(0, len(ling_collection))
154
  ling_ex = ling_collection[idx]
155
  shared_state.target = ling_ex.copy()
156
- pred, interpolation = generate_with_feedback(sent1, approx)
 
 
 
 
 
 
 
 
 
 
 
 
157
  preds.append(pred)
158
  interpolations.append(interpolation)
 
159
  return '\n***\n'.join(preds), '\n***\n'.join(interpolations)
160
 
161
  def estimate_gen(sent1, sent2, approx):
@@ -172,8 +187,8 @@ def estimate_gen(sent1, sent2, approx):
172
  ling_pred = round_ling(ling_pred)
173
  shared_state.target = ling_pred.copy()
174
 
175
- gen = generate_with_feedback(sent1, approx)
176
- return gen[0], gen[1], [gr.update(value=val) for val in shared_state.target]
177
 
178
  def estimate_tgt(sent2, ling_dict, approx):
179
  if 'approximate' in approx:
 
149
  if sent1 == '':
150
  raise gr.Error('Please input a source text.')
151
  preds, interpolations = [], []
152
+ orig_active_indices = shared_state.active_indices
153
+ shared_state.active_indices = set(range(len(lng_names)))
154
  for c in range(count):
155
  idx = np.random.randint(0, len(ling_collection))
156
  ling_ex = ling_collection[idx]
157
  shared_state.target = ling_ex.copy()
158
+ success = False
159
+ patience = 0
160
+ while not success:
161
+ pred, interpolation = generate_with_feedback(sent1, approx)[:2]
162
+ if pred not in preds:
163
+ success = True
164
+ elif patience < 3:
165
+ add_to_target()
166
+ patience += 1
167
+ else:
168
+ idx = np.random.randint(0, len(ling_collection))
169
+ ling_ex = ling_collection[idx]
170
+ shared_state.target = ling_ex.copy()
171
  preds.append(pred)
172
  interpolations.append(interpolation)
173
+ shared_state.active_indices = orig_active_indices
174
  return '\n***\n'.join(preds), '\n***\n'.join(interpolations)
175
 
176
  def estimate_gen(sent1, sent2, approx):
 
187
  ling_pred = round_ling(ling_pred)
188
  shared_state.target = ling_pred.copy()
189
 
190
+ gen = generate_with_feedback(sent1, approx)[:2]
191
+ return gen + [gr.update(value=val) for val in shared_state.target]
192
 
193
  def estimate_tgt(sent2, ling_dict, approx):
194
  if 'approximate' in approx: