zeroMN committed
Commit 0b739f4 · verified · 1 Parent(s): 3f23984

Update app.py

Files changed (1)
  1. app.py +90 -1
app.py CHANGED
@@ -1,3 +1,92 @@
+import os
+import torch
+import torch.nn as nn
+import numpy as np
+import random
 import gradio as gr
+from transformers import (
+    BartForConditionalGeneration,
+    AutoModelForCausalLM,
+    BertModel,
+    Wav2Vec2Model,
+    CLIPModel,
+    AutoTokenizer, AutoProcessor
+)
 
-gr.load("models/zeroMN/SHMT").launch()
+class MultiModalModel(nn.Module):
+    def __init__(self):
+        super(MultiModalModel, self).__init__()
+        # Initialize the sub-models
+        self.text_generator = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
+        self.code_generator = AutoModelForCausalLM.from_pretrained('gpt2')
+        self.nlp_encoder = BertModel.from_pretrained('bert-base-uncased')
+        self.speech_encoder = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')
+        self.vision_encoder = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
+
+        # Initialize the tokenizers and processors (audio/vision models need
+        # feature-extracting processors, not plain tokenizers)
+        self.text_tokenizer = AutoTokenizer.from_pretrained('facebook/bart-base')
+        self.code_tokenizer = AutoTokenizer.from_pretrained('gpt2')
+        self.nlp_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
+        self.speech_processor = AutoProcessor.from_pretrained('facebook/wav2vec2-base-960h')
+        self.vision_processor = AutoProcessor.from_pretrained('openai/clip-vit-base-patch32')
+
+    def forward(self, task, inputs):
+        if task == 'text_generation':
+            attention_mask = inputs.get('attention_mask')
+            print("Input data:", inputs)
+            outputs = self.text_generator.generate(
+                inputs['input_ids'],
+                max_new_tokens=100,
+                pad_token_id=self.text_tokenizer.pad_token_id,
+                attention_mask=attention_mask,
+                top_p=0.9,
+                top_k=50,
+                temperature=0.8,
+                do_sample=True
+            )
+            print("Generated output:", outputs)
+            return self.text_tokenizer.decode(outputs[0], skip_special_tokens=True)
+        elif task == 'code_generation':
+            attention_mask = inputs.get('attention_mask')
+            outputs = self.code_generator.generate(
+                inputs['input_ids'],
+                max_new_tokens=50,
+                pad_token_id=self.code_tokenizer.eos_token_id,  # GPT-2 has no pad token, so EOS is reused
+                attention_mask=attention_mask,
+                top_p=0.95,
+                top_k=50,
+                temperature=1.2,
+                do_sample=True
+            )
+            return self.code_tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Add logic for the other tasks...
+
+# Inference function for the Gradio interface
+def gradio_inference(task, input_text):
+    # Select the tokenizer that matches the task
+    if task == "text_generation":
+        tokenizer = model.text_tokenizer
+    elif task == "code_generation":
+        tokenizer = model.code_tokenizer
+    else:
+        raise ValueError(f"Unsupported task: {task}")
+
+    # The tokenizer already returns an attention_mask alongside input_ids
+    inputs = tokenizer(input_text, return_tensors='pt')
+
+    with torch.no_grad():
+        result = model(task, inputs)
+    return result
+
+# Initialize the model; eval() disables dropout during generation
+model = MultiModalModel().eval()
+
+# Create the Gradio interface (gr.inputs.* was removed in Gradio 3+; use gr.Dropdown / gr.Textbox)
+interface = gr.Interface(
+    fn=gradio_inference,
+    inputs=[gr.Dropdown(choices=["text_generation", "code_generation"], label="Task type"), gr.Textbox(lines=2, placeholder="Enter text...")],
+    outputs="text",
+    title="Multimodal Model Inference",
+    description="Select a task type and enter text to run inference"
+)
+
+# Launch the Gradio app
+interface.launch()
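
To sanity-check the code_generation path this commit wires into the dropdown, the same generate() call can be replicated outside the Space. A minimal standalone sketch (quick_check.py is a hypothetical helper, not part of this commit; it assumes the gpt2 checkpoint downloads on first run):

    # quick_check.py - hypothetical smoke test mirroring the code_generation branch of app.py
    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2").eval()

    # Same sampling settings as MultiModalModel.forward for code_generation
    inputs = tok("def fibonacci(n):", return_tensors="pt")
    with torch.no_grad():
        out = lm.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=50,
            do_sample=True,
            top_p=0.95,
            top_k=50,
            temperature=1.2,
            pad_token_id=tok.eos_token_id,  # GPT-2 has no dedicated pad token
        )
    print(tok.decode(out[0], skip_special_tokens=True))

Importing app.py directly would also exercise this path, but app.py calls interface.launch() at import time, which is why the sketch reconstructs the call instead.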