minko186 committed on
Commit
e1b0f65
·
verified ·
1 Parent(s): d96b873

Update ai_generate.py

Browse files
Files changed (1) hide show
  1. ai_generate.py +43 -11
ai_generate.py CHANGED
@@ -1,19 +1,51 @@
1
  import torch
2
  from openai import OpenAI
3
  import os
 
4
 
5
- model = "gpt-3.5-turbo"
6
  client = OpenAI(
7
  api_key=os.environ.get("openai_key"),
8
  )
9
 
10
- def generate(text):
11
- message=[{"role": "user", "content": text}]
12
- response = client.chat.completions.create(
13
- model=model,
14
- messages = message,
15
- temperature=0.2,
16
- max_tokens=800,
17
- frequency_penalty=0.0
18
- )
19
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from openai import OpenAI
3
  import os
4
+ from transformers import pipeline
5
 
 
6
  client = OpenAI(
7
  api_key=os.environ.get("openai_key"),
8
  )
9
 
10
+ pipes = {
11
+ 'GPT-Neo': pipeline("text-generation", model="EleutherAI/gpt-neo-2.7B"),
12
+ 'Llama 3': pipeline("text-generation", model="meta-llama/Meta-Llama-3-8B")
13
+ }
14
+
15
+ def generate(text, model):
16
+ if model is "GPT-Neo":
17
+ response = pipes[model](text)
18
+ return response[0]
19
+ elif model is "Llama 3":
20
+ response = pipes[model](text)
21
+ return response[0]
22
+ elif model is "OpenAI GPT 3.5":
23
+ message=[{"role": "user", "content": text}]
24
+ response = client.chat.completions.create(
25
+ model="gpt-3.5-turbo",
26
+ messages = message,
27
+ temperature=0.2,
28
+ max_tokens=800,
29
+ frequency_penalty=0.0
30
+ )
31
+ return response[0].message.content
32
+ elif model is "OpenAI GPT 4":
33
+ message=[{"role": "user", "content": text}]
34
+ response = client.chat.completions.create(
35
+ model="gpt-4-turbo",
36
+ messages = message,
37
+ temperature=0.2,
38
+ max_tokens=800,
39
+ frequency_penalty=0.0
40
+ )
41
+ return response[0].message.content
42
+ elif model is "OpenAI GPT 4o":
43
+ message=[{"role": "user", "content": text}]
44
+ response = client.chat.completions.create(
45
+ model="gpt-4o",
46
+ messages = message,
47
+ temperature=0.2,
48
+ max_tokens=800,
49
+ frequency_penalty=0.0
50
+ )
51
+ return response[0].message.content