OpenAI SDK开发(1)
本次完成的是基本框架的搭建,项目结构如下图所示:

common
Constants
common包下定义了Constants类,里面暂时只写了一个枚举Role,用作Message中的role参数.由于Request和Response中都包含Message,所以把Constants放在common包下,后面会用到.

 /**
  * Constants shared by the request and response models.
  */
 public class Constants {

     /**
      * Author role of a chat message. {@link #getCode()} is the string value stored in a
      * message's {@code role} field.
      */
     public enum Role {

         SYSTEM("system"),
         USER("user"),
         ASSISTANT("assistant"),
         ;

         /** Wire value of this role; final because enum constants are shared singletons. */
         private final String code;

         Role(String code) {
             this.code = code;
         }

         /**
          * @return the wire value of this role, e.g. {@code "user"}
          */
         public String getCode() {
             return code;
         }
     }
 }
 
 

domain
chat(聊天模型)
ChatChoice

 | @Datapublic class ChatChoice implements Serializable {
 
 private long index;
 @JsonProperty("message")
 private Message message;
 @JsonProperty("finish_reason")
 private String finishReason;
 
 }
 

这里定义的是choices中每一项的参数,choices是Response中的一个字段.

ChatCompletionRequest(聊天完成请求)
把model单独写成了一个枚举类,并定义了请求所需的参数.
这些参数里只有model和messages是必需的,其他都是可选的,所以用Builder模式来构建对象.
注意:Lombok的@Builder会忽略字段上的初始值.只有在字段上加了@Builder.Default,这些默认值在通过builder()构建对象时才会生效.

 | @Data@Builder
 @Slf4j
 @JsonInclude(JsonInclude.Include.NON_NULL)
 @NoArgsConstructor
 @AllArgsConstructor
 public class ChatCompletionRequest implements Serializable {
 
 
 private String model = Model.GPT_3_5_TURBO.getCode();
 
 private List<Message> messages;
 
 private double temperature = 0.2;
 
 @JsonProperty("top_p")
 private Double topP = 1d;
 
 private Integer n = 1;
 
 private boolean stream = false;
 
 private List<String> stop;
 
 @JsonProperty("max_tokens")
 private Integer maxTokens = 2048;
 
 @JsonProperty("frequency_penalty")
 private double frequencyPenalty = 0;
 
 @JsonProperty("presence_penalty")
 private double presencePenalty = 0;
 
 @JsonProperty("logit_bias")
 private Map logitBias;
 
 private String user;
 
 @Getter
 @AllArgsConstructor
 public enum Model {
 
 GPT_3_5_TURBO("gpt-3.5-turbo"),
 
 GPT_4("gpt-4"),
 
 GPT_4_32K("gpt-4-32k"),
 ;
 private String code;
 }
 
 }
 

ChatCompletionResponse(聊天完成响应)
定义了Response中的参数.

 | @Datapublic class ChatCompletionResponse implements Serializable {
 
 
 private String id;
 
 private String object;
 
 private String model;
 
 private List<ChatChoice> choices;
 
 private long created;
 
 private Usage usage;
 
 }
 

Message
定义了聊天消息对象,包含消息角色(role)、消息内容(content)和消息名称(name).

 | @Data@JsonInclude(JsonInclude.Include.NON_NULL)
 public class Message implements Serializable {
 
 private String role;
 private String content;
 private String name;
 
 public Message() {
 }
 
 private Message(Builder builder) {
 this.role = builder.role;
 this.content = builder.content;
 this.name = builder.name;
 }
 
 public static Builder builder() {
 return new Builder();
 }
 
 
 
 
 public static final class Builder {
 
 private String role;
 private String content;
 private String name;
 
 public Builder() {
 }
 
 public Builder role(Constants.Role role) {
 this.role = role.getCode();
 return this;
 }
 
 public Builder content(String content) {
 this.content = content;
 return this;
 }
 
 public Builder name(String name) {
 this.name = name;
 return this;
 }
 
 public Message build() {
 return new Message(this);
 }
 }
 
 }
 

other
Usage(使用量)
Response中的一个参数,记录了token的使用量.

 | public class Usage implements Serializable {
 
 @JsonProperty("prompt_tokens")
 private long promptTokens;
 
 @JsonProperty("completion_tokens")
 private long completionTokens;
 
 @JsonProperty("total_tokens")
 private long totalTokens;
 
 public long getPromptTokens() {
 return promptTokens;
 }
 
 public void setPromptTokens(long promptTokens) {
 this.promptTokens = promptTokens;
 }
 
 public long getCompletionTokens() {
 return completionTokens;
 }
 
 public void setCompletionTokens(long completionTokens) {
 this.completionTokens = completionTokens;
 }
 
 public long getTotalTokens() {
 return totalTokens;
 }
 
 public void setTotalTokens(long totalTokens) {
 this.totalTokens = totalTokens;
 }
 
 }
 
 

OpenAiResponse

 | @Datapublic class OpenAiResponse<T> implements Serializable {
 
 private String object;
 private List<T> data;
 private Error error;
 
 
 @Data
 public class Error {
 private String message;
 private String type;
 private String param;
 private String code;
 }
 
 }
 

qa(问答模型)

问答模型对应的接口很快就不能用了,而且实现跟聊天模型差不多,这里就不贴代码了.
QAChoice

QACompletionRequest
model和prompt参数是必需的,其他参数可选.

QACompletionResponse

interceptor
OpenAiInterceptor(自定义拦截器)
auth方法把token作为查询参数加到url上,并设置Authorization和Content-Type请求头,然后返回一个新的请求.intercept用它对请求进行预处理,再把处理后的请求交给下一个拦截器(或最终的网络调用)继续处理.

 | public class OpenAiInterceptor implements Interceptor {
 
 private String apiKey;
 
 private String authToken;
 
 public OpenAiInterceptor(String apiKey, String authToken) {
 this.apiKey = apiKey;
 this.authToken = authToken;
 }
 
 @NotNull
 @Override
 public Response intercept(Chain chain) throws IOException {
 return chain.proceed(this.auth(apiKey, chain.request()));
 }
 
 private Request auth(String apiKey, Request original) {
 
 HttpUrl url = original.url().newBuilder()
 .addQueryParameter("token", authToken)
 .build();
 
 
 return original.newBuilder()
 .url(url)
 .header(Header.AUTHORIZATION.getValue(), "Bearer " + apiKey)
 .header(Header.CONTENT_TYPE.getValue(), ContentType.JSON.getValue())
 .method(original.method(), original.body())
 .build();
 }
 
 }
 

session
IOpenAiApi
定义访问接口(基于Retrofit注解),传入请求对象.

 | public interface IOpenAiApi {
 
 
 
 
 
 
 
 
 
 
 
 @POST("v1/completions")
 Single<QACompletionResponse> completions(@Body QACompletionRequest qaCompletionRequest);
 
 
 
 
 
 
 @POST("v1/chat/completions")
 Single<ChatCompletionResponse> completions(@Body ChatCompletionRequest chatCompletionRequest);
 
 }
 

OpenAiSession
会话接口.

 | public interface OpenAiSession {
 
 
 
 
 
 QACompletionResponse completions(QACompletionRequest qaCompletionRequest);
 
 
 
 
 
 
 QACompletionResponse completions(String question);
 
 
 
 
 
 
 ChatCompletionResponse completions(ChatCompletionRequest chatCompletionRequest);
 
 }
 

OpenAiSessionFactory
会话工厂接口.

 | public interface OpenAiSessionFactory {
 OpenAiSession openSession();
 
 }
 

Configuration
配置类.

 | @Slf4j@Data
 @NoArgsConstructor
 @AllArgsConstructor
 public class Configuration {
 
 @Getter
 @NotNull
 private String apiKey;
 
 @Getter
 private String apiHost;
 
 @Getter
 
 private String authToken;
 
 }
 

DefaultOpenAiSession
实现OpenAiSession接口.
blockingGet()是RxJava中Single的方法,会阻塞当前线程直到结果返回.这里用它把异步调用转换为同步调用:调用这个方法的线程会等待异步请求完成后才继续执行.

 | public class DefaultOpenAiSession implements OpenAiSession {
 private IOpenAiApi openAiApi;
 
 public DefaultOpenAiSession(IOpenAiApi openAiApi) {
 this.openAiApi = openAiApi;
 }
 
 @Override
 public QACompletionResponse completions(QACompletionRequest qaCompletionRequest) {
 return this.openAiApi.completions(qaCompletionRequest).blockingGet();
 }
 
 @Override
 public QACompletionResponse completions(String question) {
 QACompletionRequest request = QACompletionRequest
 .builder()
 .prompt(question)
 .build();
 Single<QACompletionResponse> completions = this.openAiApi.completions(request);
 return completions.blockingGet();
 }
 
 @Override
 public ChatCompletionResponse completions(ChatCompletionRequest chatCompletionRequest) {
 return this.openAiApi.completions(chatCompletionRequest).blockingGet();
 }
 
 }
 

DefaultOpenAiSessionFactory
实现OpenAiSessionFactory接口:在openSession()中配置OkHttpClient,通过Retrofit生成IOpenAiApi接口的实现,最后返回一个DefaultOpenAiSession(openAiApi).

 | public class DefaultOpenAiSessionFactory implements OpenAiSessionFactory {
 private final Configuration configuration;
 
 public DefaultOpenAiSessionFactory(Configuration configuration) {
 this.configuration = configuration;
 }
 
 @Override
 public OpenAiSession openSession() {
 
 HttpLoggingInterceptor httpLoggingInterceptor = new HttpLoggingInterceptor();
 httpLoggingInterceptor.setLevel(HttpLoggingInterceptor.Level.HEADERS);
 
 
 OkHttpClient okHttpClient = new OkHttpClient
 .Builder()
 .addInterceptor(httpLoggingInterceptor)
 .addInterceptor(new OpenAiInterceptor(configuration.getApiKey(), configuration.getAuthToken()))
 .connectTimeout(450, TimeUnit.SECONDS)
 .writeTimeout(450, TimeUnit.SECONDS)
 .readTimeout(450, TimeUnit.SECONDS)
 .build();
 
 
 IOpenAiApi openAiApi = new Retrofit.Builder()
 .baseUrl(configuration.getApiHost())
 .client(okHttpClient)
 
 
 
 
 .addCallAdapterFactory(RxJava2CallAdapterFactory.create())
 .addConverterFactory(JacksonConverterFactory.create())
 .build().create(IOpenAiApi.class);
 
 return new DefaultOpenAiSession(openAiApi);
 }
 
 }
 

单元测试
只需要传入url、key、token和request即可.

 | @Slf4jpublic class ApiTest {
 
 private OpenAiSession openAiSession;
 
 @Before
 public void test_OpenAiSessionFactory() {
 
 Configuration configuration = new Configuration();
 configuration.setApiHost("https://api.openai-proxy.com/");
 configuration.setApiKey("xxx");
 configuration.setAuthToken("xxx");
 
 OpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
 
 this.openAiSession = factory.openSession();
 }
 
 
 
 
 @Test
 public void test_chat_completions() {
 
 ChatCompletionRequest chatCompletion = ChatCompletionRequest
 .builder()
 .messages(Collections.singletonList(Message.builder().role(Constants.Role.USER).content("写一个java冒泡排序").build()))
 .model(ChatCompletionRequest.Model.GPT_3_5_TURBO.getCode())
 .build();
 
 ChatCompletionResponse chatCompletionResponse = openAiSession.completions(chatCompletion);
 
 chatCompletionResponse.getChoices().forEach(e -> {
 log.info("测试结果:{}", e.getMessage());
 });
 }
 
 }
 