OpenAI SDK开发(1)

本次完成的是基本框架的搭建,项目结构如下图所示:

avatar

common

Constants

common包下定义了Constants类,里面暂时写了一个枚举对象Role,是要用在Message中的一个参数,而Message在Request和Response中都有,所以放在common包下,后面会用到.

avatar

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
public class Constants {

/**
* 官网支持的请求角色类型;system、user、assistant
* https://platform.openai.com/docs/guides/chat/introduction
*/
public enum Role {

SYSTEM("system"),
USER("user"),
ASSISTANT("assistant"),
;

private String code;

Role(String code) {
this.code = code;
}

public String getCode() {
return code;
}

}

}

domain

chat(聊天模型)

ChatChoice
1
2
3
4
5
6
7
8
9
10
@Data
public class ChatChoice implements Serializable {

private long index;
@JsonProperty("message")
private Message message;
@JsonProperty("finish_reason")
private String finishReason;

}

这里面定义的是choices中的几个参数,choices参数是在Response中的

avatar

ChatCompletionRequest(聊天完成请求)

把model单独写了一个枚举类,定义类所需参数

这些参数里只有model和message是必须的,其他的都是可选的,所以用Builder模式来构建对象

avatar

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
@Data
@Builder
@Slf4j
@JsonInclude(JsonInclude.Include.NON_NULL)
@NoArgsConstructor
@AllArgsConstructor
public class ChatCompletionRequest implements Serializable {

/** 默认模型 */
private String model = Model.GPT_3_5_TURBO.getCode();
/** 问题描述 */
private List<Message> messages;
/** 控制温度【随机性】;0到2之间。较高的值(如0.8)将使输出更加随机,而较低的值(如0.2)将使输出更加集中和确定 */
private double temperature = 0.2;
/** 多样性控制;使用温度采样的替代方法称为核心采样,其中模型考虑具有top_p概率质量的令牌的结果。因此,0.1 意味着只考虑包含前 10% 概率质量的代币 */
@JsonProperty("top_p")
private Double topP = 1d;
/** 为每个提示生成的完成次数 */
private Integer n = 1;
/** 是否为流式输出;就是一蹦一蹦的,出来结果 */
private boolean stream = false;
/** 停止输出标识 */
private List<String> stop;
/** 输出字符串限制;0 ~ 4096 */
@JsonProperty("max_tokens")
private Integer maxTokens = 2048;
/** 频率惩罚;降低模型重复同一行的可能性 */
@JsonProperty("frequency_penalty")
private double frequencyPenalty = 0;
/** 存在惩罚;增强模型谈论新话题的可能性 */
@JsonProperty("presence_penalty")
private double presencePenalty = 0;
/** 生成多个调用结果,只显示最佳的。这样会更多的消耗你的 api token */
@JsonProperty("logit_bias")
private Map logitBias;
/** 调用标识,避免重复调用 */
private String user;

@Getter
@AllArgsConstructor
public enum Model {
/** gpt-3.5-turbo */
GPT_3_5_TURBO("gpt-3.5-turbo"),
/** GPT4.0 */
GPT_4("gpt-4"),
/** GPT4.0 超长上下文 */
GPT_4_32K("gpt-4-32k"),
;
private String code;
}

}
ChatCompletionResponse(聊天完成响应)

定义了Response中的参数

avatar

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
@Data
public class ChatCompletionResponse implements Serializable {

/** ID */
private String id;
/** 对象 */
private String object;
/** 模型 */
private String model;
/** 对话 */
private List<ChatChoice> choices;
/** 创建 */
private long created;
/** 耗材 */
private Usage usage;

}
Message

定义的聊天消息对象,包含消息角色、消息内容、消息名称

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
@Data
@JsonInclude(JsonInclude.Include.NON_NULL)
public class Message implements Serializable {

private String role;
private String content;
private String name;

public Message() {
}

private Message(Builder builder) {
this.role = builder.role;
this.content = builder.content;
this.name = builder.name;
}

public static Builder builder() {
return new Builder();
}

/**
* 建造者模式
*/
public static final class Builder {

private String role;
private String content;
private String name;

public Builder() {
}

public Builder role(Constants.Role role) {
this.role = role.getCode();
return this;
}

public Builder content(String content) {
this.content = content;
return this;
}

public Builder name(String name) {
this.name = name;
return this;
}

public Message build() {
return new Message(this);
}
}

}

other

Usage(使用量)

是Response中的一个参数,记录了token的使用量

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
public class Usage implements Serializable {

/** 提示令牌 */
@JsonProperty("prompt_tokens")
private long promptTokens;
/** 完成令牌 */
@JsonProperty("completion_tokens")
private long completionTokens;
/** 总量令牌 */
@JsonProperty("total_tokens")
private long totalTokens;

public long getPromptTokens() {
return promptTokens;
}

public void setPromptTokens(long promptTokens) {
this.promptTokens = promptTokens;
}

public long getCompletionTokens() {
return completionTokens;
}

public void setCompletionTokens(long completionTokens) {
this.completionTokens = completionTokens;
}

public long getTotalTokens() {
return totalTokens;
}

public void setTotalTokens(long totalTokens) {
this.totalTokens = totalTokens;
}

}

OpenAiResponse
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
@Data
public class OpenAiResponse<T> implements Serializable {

private String object;
private List<T> data;
private Error error;


@Data
public class Error {
private String message;
private String type;
private String param;
private String code;
}

}

qa(问答模型)

avatar
很快就不能用了,而且跟聊天模型差不多,就不贴代码了

QAChoice

avatar

QACompletionRequest

model和prompt参数必要,其他参数可选

avatar

QACompletionResponse

avatar

interceptor

OpenAiInterceptor(自定义拦截器)

auth方法将token参数加入url对象,返回一个新的请求,intercept对该请求进行预处理,然后将处理后的请求传递给下一个拦截器(或目标方法)继续处理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
public class OpenAiInterceptor implements Interceptor {

/** OpenAi apiKey 需要在官网申请 */
private String apiKey;
/** 访问授权接口的认证 Token */
private String authToken;

public OpenAiInterceptor(String apiKey, String authToken) {
this.apiKey = apiKey;
this.authToken = authToken;
}

@NotNull
@Override
public Response intercept(Chain chain) throws IOException {
return chain.proceed(this.auth(apiKey, chain.request()));
}

private Request auth(String apiKey, Request original) {
// 设置Token信息;如果没有此类限制,是不需要设置的。
HttpUrl url = original.url().newBuilder()
.addQueryParameter("token", authToken)
.build();

// 创建请求
return original.newBuilder()
.url(url)
.header(Header.AUTHORIZATION.getValue(), "Bearer " + apiKey)
.header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue())
.method(original.method(), original.body())
.build();
}

}

session

IOpenAiApi

定义访问接口,传入请求

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
public interface IOpenAiApi {


/**
* @Body注解用于描述一个接口中的一个方法参数,该参数将接收请求体中的表单数据。
* 当处理HTTP POST请求时,通常将请求体中的表单数据映射到接口中的方法参数。
**/

/**
* 文本问答
* @param qaCompletionRequest 请求信息
* @return 返回结果
*/
@POST("v1/completions")
Single<QACompletionResponse> completions(@Body QACompletionRequest qaCompletionRequest);

/**
* 默认 GPT-3.5 问答模型
* @param chatCompletionRequest 请求信息
* @return 返回结果
*/
@POST("v1/chat/completions")
Single<ChatCompletionResponse> completions(@Body ChatCompletionRequest chatCompletionRequest);

}
OpenAiSession

会话接口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
public interface OpenAiSession {

/**
* 文本问答
* @param qaCompletionRequest 请求信息
* @return 返回结果
*/
QACompletionResponse completions(QACompletionRequest qaCompletionRequest);

/**
* 文本问答;简单请求
* @param question 请求信息
* @return 返回结果
*/
QACompletionResponse completions(String question);

/**
* 默认 GPT-3.5 问答模型
* @param chatCompletionRequest 请求信息
* @return 返回结果
*/
ChatCompletionResponse completions(ChatCompletionRequest chatCompletionRequest);

}
OpenAiSessionFactory

会话工厂接口

1
2
3
4
5
public interface OpenAiSessionFactory {

OpenAiSession openSession();

}
Configuration

配置类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
@Slf4j
@Data
@NoArgsConstructor
@AllArgsConstructor
public class Configuration {

@Getter
@NotNull
private String apiKey;

@Getter
private String apiHost;

@Getter
// @NotNull
private String authToken;

}
DefaultOpenAiSession

实现OpenAiSession接口

blockingGet()是RxJava中Single中的方法,用于将当前线程阻塞,这里的作用是将异步计算的结果转换为同步结果,使得调用这个方法的线程会等待异步计算完成后才继续执行

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
public class DefaultOpenAiSession implements OpenAiSession {

private IOpenAiApi openAiApi;

public DefaultOpenAiSession(IOpenAiApi openAiApi) {
this.openAiApi = openAiApi;
}

@Override
public QACompletionResponse completions(QACompletionRequest qaCompletionRequest) {
return this.openAiApi.completions(qaCompletionRequest).blockingGet();
}

@Override
public QACompletionResponse completions(String question) {
QACompletionRequest request = QACompletionRequest
.builder()
.prompt(question)
.build();
Single<QACompletionResponse> completions = this.openAiApi.completions(request);
return completions.blockingGet();
}

@Override
public ChatCompletionResponse completions(ChatCompletionRequest chatCompletionRequest) {
return this.openAiApi.completions(chatCompletionRequest).blockingGet();
}

}
DefaultOpenAiSessionFactory

实现OpenAiSessionFactory接口,其实实现的是IOpenAiApi接口,返回一个DefaultOpenAiSession(openAiApi)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
public class DefaultOpenAiSessionFactory implements OpenAiSessionFactory {

private final Configuration configuration;

public DefaultOpenAiSessionFactory(Configuration configuration) {
this.configuration = configuration;
}

@Override
public OpenAiSession openSession() {
// 1. 日志配置
HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor();
httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);

// 2. 开启 Http 客户端
OkHttpClient okHttpClient = new OkHttpClient
.Builder()
.addInterceptor(httpLoggingInterceptor)
.addInterceptor(new OpenAiInterceptor(configuration.getApiKey(), configuration.getAuthToken()))
.connectTimeout(450, TimeUnit.SECONDS)
.writeTimeout(450, TimeUnit.SECONDS)
.readTimeout(450, TimeUnit.SECONDS)
.build();

// 3. 实现IOpenAiApi接口,创建 API 服务,即网络请求接口对象实例
IOpenAiApi openAiApi = new Retrofit.Builder()
.baseUrl(configuration.getApiHost())//得到url
.client(okHttpClient)//设置客户端
// RxJava2CallAdapterFactory的主要作用是:
// 1.将Android的Call对象转换为RxJava的Observable类型。
// 2.处理Call的错误和结果,并将其转换为RxJava的onError和onNext事件。
// 3.添加适当的错误处理逻辑,例如重试、网络错误等。
.addCallAdapterFactory(RxJava2CallAdapterFactory.create())
.addConverterFactory(JacksonConverterFactory.create())//在请求和响应中使用jackson库进行json转换
.build().create(IOpenAiApi.class);

return new DefaultOpenAiSession(openAiApi);
}

}

单元测试

传url,key,token和request就行

avatar

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
@Slf4j
public class ApiTest {

private OpenAiSession openAiSession;

@Before
public void test_OpenAiSessionFactory() {
// 1. 配置文件
Configuration configuration = new Configuration();
configuration.setApiHost("https://api.openai-proxy.com/");
configuration.setApiKey("xxx");
configuration.setAuthToken("xxx");
// 2. 会话工厂
OpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
// 3. 开启会话
this.openAiSession = factory.openSession();
}

/**
* 此对话模型 3.5 接近于官网体验
*/
@Test
public void test_chat_completions() {
// 1. 创建参数
ChatCompletionRequest chatCompletion = ChatCompletionRequest
.builder()
.messages(Collections.singletonList(Message.builder().role(Constants.Role.USER).content("写一个java冒泡排序").build()))
.model(ChatCompletionRequest.Model.GPT_3_5_TURBO.getCode())
.build();
// 2. 发起请求
ChatCompletionResponse chatCompletionResponse = openAiSession.completions(chatCompletion);
// 3. 解析结果
chatCompletionResponse.getChoices().forEach(e -> {
log.info("测试结果:{}", e.getMessage());
});
}

}